"""Vendor Cost Tracker Router - Track and analyze AI API spending"""

from collections import defaultdict
from datetime import datetime, date, timedelta
from typing import Optional
import uuid

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel

router = APIRouter()

# In-memory storage for cost entries and alerts.
# Both containers are only ever mutated in place so that any other holder of
# a reference to them keeps seeing the current data.
cost_entries: list = []
budget_alerts: dict = {}

# Pricing data keyed by provider, then model. Each model carries its own
# "unit": most are priced per 1M tokens, a few per minute or per image.
PROVIDER_PRICING = {
    "openai": {
        "gpt-4o": {"input": 2.50, "output": 10.00, "unit": "1M tokens"},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60, "unit": "1M tokens"},
        "gpt-4-turbo": {"input": 10.00, "output": 30.00, "unit": "1M tokens"},
        "gpt-4": {"input": 30.00, "output": 60.00, "unit": "1M tokens"},
        "gpt-3.5-turbo": {"input": 0.50, "output": 1.50, "unit": "1M tokens"},
        "text-embedding-3-small": {"input": 0.02, "output": 0.0, "unit": "1M tokens"},
        "text-embedding-3-large": {"input": 0.13, "output": 0.0, "unit": "1M tokens"},
        "whisper": {"input": 0.006, "output": 0.0, "unit": "per minute"},
        "dall-e-3": {"input": 0.04, "output": 0.0, "unit": "per image (1024x1024)"},
    },
    "anthropic": {
        "claude-opus-4": {"input": 15.00, "output": 75.00, "unit": "1M tokens"},
        "claude-sonnet-4": {"input": 3.00, "output": 15.00, "unit": "1M tokens"},
        "claude-3.5-sonnet": {"input": 3.00, "output": 15.00, "unit": "1M tokens"},
        "claude-3.5-haiku": {"input": 0.80, "output": 4.00, "unit": "1M tokens"},
        "claude-3-opus": {"input": 15.00, "output": 75.00, "unit": "1M tokens"},
        "claude-3-sonnet": {"input": 3.00, "output": 15.00, "unit": "1M tokens"},
        "claude-3-haiku": {"input": 0.25, "output": 1.25, "unit": "1M tokens"},
    },
    "google": {
        "gemini-2.0-flash": {"input": 0.10, "output": 0.40, "unit": "1M tokens"},
        "gemini-1.5-pro": {"input": 1.25, "output": 5.00, "unit": "1M tokens"},
        "gemini-1.5-flash": {"input": 0.075, "output": 0.30, "unit": "1M tokens"},
        "gemini-1.0-pro": {"input": 0.50, "output": 1.50, "unit": "1M tokens"},
    },
    "aws": {
        "bedrock-claude-3-opus": {"input": 15.00, "output": 75.00, "unit": "1M tokens"},
        "bedrock-claude-3-sonnet": {"input": 3.00, "output": 15.00, "unit": "1M tokens"},
        "bedrock-claude-3-haiku": {"input": 0.25, "output": 1.25, "unit": "1M tokens"},
        "bedrock-titan-text": {"input": 0.80, "output": 1.00, "unit": "1M tokens"},
        "bedrock-titan-embeddings": {"input": 0.10, "output": 0.0, "unit": "1M tokens"},
    },
    "azure": {
        "azure-gpt-4o": {"input": 2.50, "output": 10.00, "unit": "1M tokens"},
        "azure-gpt-4-turbo": {"input": 10.00, "output": 30.00, "unit": "1M tokens"},
        "azure-gpt-4": {"input": 30.00, "output": 60.00, "unit": "1M tokens"},
        "azure-gpt-35-turbo": {"input": 0.50, "output": 1.50, "unit": "1M tokens"},
    },
    "cohere": {
        "command-r-plus": {"input": 2.50, "output": 10.00, "unit": "1M tokens"},
        "command-r": {"input": 0.15, "output": 0.60, "unit": "1M tokens"},
        "embed-english-v3.0": {"input": 0.10, "output": 0.0, "unit": "1M tokens"},
    },
    "mistral": {
        "mistral-large": {"input": 2.00, "output": 6.00, "unit": "1M tokens"},
        "mistral-small": {"input": 0.20, "output": 0.60, "unit": "1M tokens"},
        "mistral-embed": {"input": 0.10, "output": 0.0, "unit": "1M tokens"},
    }
}


class CostEntry(BaseModel):
    # Request body for logging one cost record.
    provider: str  # stored lowercased
    model: Optional[str] = None
    amount: float
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    requests: Optional[int] = None
    project: Optional[str] = "default"  # falsy values are stored as "default"
    description: Optional[str] = None
    entry_date: date


class BudgetAlert(BaseModel):
    # Request body for configuring a monthly budget alert.
    name: str
    provider: Optional[str] = None  # None means "all providers"
    project: Optional[str] = None  # None means "all projects"
    monthly_limit: float
    alert_threshold: float = 0.8  # fraction of monthly_limit that triggers a warning


class CostSummary(BaseModel):
    # Shape of the payload returned by the /summary endpoint.
    total: float
    by_provider: dict
    by_project: dict
    by_model: dict
    daily_breakdown: list
    period_start: str
    period_end: str
    entry_count: int


class TokenUsageEstimate(BaseModel):
    # Request body for the /estimate endpoint.
    provider: str
    model: str
    input_tokens: int
    output_tokens: int


def _new_id() -> str:
    """Return a short identifier: the first 8 characters of a random UUID4."""
    return str(uuid.uuid4())[:8]


def _build_cost_record(entry: CostEntry) -> dict:
    """Convert a validated CostEntry into the dict stored in `cost_entries`."""
    return {
        "id": _new_id(),
        "provider": entry.provider.lower(),
        "model": entry.model,
        "amount": entry.amount,
        "input_tokens": entry.input_tokens,
        "output_tokens": entry.output_tokens,
        "requests": entry.requests,
        "project": entry.project or "default",
        "description": entry.description,
        "entry_date": entry.entry_date.isoformat(),
        "created_at": datetime.now().isoformat()
    }


def _in_scope(entry: dict, provider: Optional[str], project: Optional[str]) -> bool:
    """Return True when `entry` passes the optional provider/project filters.

    `provider` must already be lowercased; a falsy filter matches everything.
    """
    if provider and entry["provider"] != provider:
        return False
    if project and entry["project"] != project:
        return False
    return True


def _month_start(anchor: date, offset: int = 0) -> date:
    """Return the first day of the month `offset` months away from `anchor`."""
    # Work on a zero-based month index so year roll-over is plain arithmetic.
    index = anchor.year * 12 + (anchor.month - 1) + offset
    return date(index // 12, index % 12 + 1, 1)


def _alert_month_spend(alert: dict, today: date) -> float:
    """Sum the amounts of entries in `today`'s calendar month within the alert's scope."""
    month_start = _month_start(today)
    next_month_start = _month_start(today, 1)
    return sum(
        e["amount"]
        for e in cost_entries
        if _in_scope(e, alert["provider"], alert["project"])
        # Upper bound keeps entries dated in a future month out of this month's spend.
        and month_start <= date.fromisoformat(e["entry_date"]) < next_month_start
    )


@router.post("/log")
async def log_cost(entry: CostEntry):
    """Log a cost entry"""
    cost_record = _build_cost_record(entry)
    cost_entries.append(cost_record)

    # Check budget alerts against the *stored* provider/project so that an
    # entry saved under "default" is matched by alerts scoped to "default".
    triggered_alerts = check_budget_alerts(cost_record["provider"], cost_record["project"])

    return {
        "message": "Cost logged successfully",
        "entry_id": cost_record["id"],
        "entry": cost_record,
        "alerts_triggered": triggered_alerts
    }


@router.post("/log-batch")
async def log_costs_batch(entries: list[CostEntry]):
    """Log multiple cost entries at once"""
    results = [_build_cost_record(entry) for entry in entries]
    cost_entries.extend(results)

    return {
        "message": f"Logged {len(results)} cost entries",
        "entries": results
    }


@router.get("/summary")
async def get_cost_summary(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    provider: Optional[str] = None,
    project: Optional[str] = None
):
    """Get cost summary for a period"""
    # Default to the current month, from its first day up to today
    today = date.today()
    if not start_date:
        start_date = date(today.year, today.month, 1)
    if not end_date:
        end_date = today

    # Filter entries (both date bounds are inclusive)
    wanted_provider = provider.lower() if provider else None
    filtered = [
        e for e in cost_entries
        if start_date <= date.fromisoformat(e["entry_date"]) <= end_date
        and _in_scope(e, wanted_provider, project)
    ]

    # Aggregate
    total = sum(e["amount"] for e in filtered)

    by_provider = defaultdict(float)
    by_project = defaultdict(float)
    by_model = defaultdict(float)
    daily = defaultdict(float)

    for entry in filtered:
        by_provider[entry["provider"]] += entry["amount"]
        by_project[entry["project"]] += entry["amount"]
        if entry["model"]:
            by_model[f"{entry['provider']}/{entry['model']}"] += entry["amount"]
        daily[entry["entry_date"]] += entry["amount"]

    # ISO date strings sort chronologically
    daily_breakdown = [
        {"date": d, "amount": round(a, 2)}
        for d, a in sorted(daily.items())
    ]

    def ranked(totals: dict) -> dict:
        # Largest spend first, amounts rounded to cents.
        return {k: round(v, 2) for k, v in sorted(totals.items(), key=lambda x: -x[1])}

    return {
        "total": round(total, 2),
        "by_provider": ranked(by_provider),
        "by_project": ranked(by_project),
        "by_model": ranked(by_model),
        "daily_breakdown": daily_breakdown,
        "period_start": start_date.isoformat(),
        "period_end": end_date.isoformat(),
        "entry_count": len(filtered)
    }


@router.get("/entries")
async def get_cost_entries(
    # Lower bounds reject negative values, which would otherwise be
    # interpreted as from-the-end slice indices.
    limit: int = Query(100, ge=0, le=1000),
    offset: int = Query(0, ge=0),
    provider: Optional[str] = None,
    project: Optional[str] = None
):
    """Get individual cost entries with pagination"""
    wanted_provider = provider.lower() if provider else None
    filtered = [e for e in cost_entries if _in_scope(e, wanted_provider, project)]

    # Sort by date descending (stable, so same-day entries keep insertion order)
    filtered.sort(key=lambda x: x["entry_date"], reverse=True)

    return {
        "entries": filtered[offset:offset + limit],
        "total": len(filtered),
        "limit": limit,
        "offset": offset
    }


@router.delete("/entries/{entry_id}")
async def delete_cost_entry(entry_id: str):
    """Delete a cost entry"""
    remaining = [e for e in cost_entries if e["id"] != entry_id]

    if len(remaining) == len(cost_entries):
        raise HTTPException(status_code=404, detail="Entry not found")

    # Replace the contents in place instead of rebinding the module global.
    cost_entries[:] = remaining
    return {"message": "Entry deleted", "entry_id": entry_id}


@router.get("/forecast")
async def forecast_costs(
    months: int = Query(3, ge=1, le=12),
    provider: Optional[str] = None,
    project: Optional[str] = None
):
    """Forecast future costs based on usage patterns"""
    # NOTE(review): this counts entries, not distinct days, although the
    # message talks about days - confirm which one is intended.
    if len(cost_entries) < 7:
        return {
            "message": "Need at least 7 days of data for forecasting",
            "forecast": [],
            "confidence": 0.0
        }

    # Get last 30 days of data (future-dated entries are not "recent" usage)
    today = date.today()
    thirty_days_ago = today - timedelta(days=30)
    wanted_provider = provider.lower() if provider else None

    recent = [
        e for e in cost_entries
        if thirty_days_ago <= date.fromisoformat(e["entry_date"]) <= today
        and _in_scope(e, wanted_provider, project)
    ]

    if not recent:
        return {
            "message": "No recent data for forecasting",
            "forecast": [],
            "confidence": 0.0
        }

    # Calculate daily average over the days that actually have entries
    daily_totals = defaultdict(float)
    for entry in recent:
        daily_totals[entry["entry_date"]] += entry["amount"]

    daily_avg = sum(daily_totals.values()) / max(len(daily_totals), 1)

    # Every forecast month is treated as 30 days long (simplified), so the
    # prediction is the same flat figure for each month.
    days_in_month = 30
    base_forecast = daily_avg * days_in_month

    forecast = []
    for m in range(1, months + 1):
        # Step by calendar months: adding 30-day chunks to `today` can repeat
        # the current month or skip a short one.
        forecast_month = _month_start(today, m)
        forecast.append({
            "month": forecast_month.strftime("%Y-%m"),
            "predicted_cost": round(base_forecast, 2),
            # Fixed +/-20% band as a rough uncertainty range
            "lower_bound": round(base_forecast * 0.8, 2),
            "upper_bound": round(base_forecast * 1.2, 2)
        })

    # Confidence grows with the number of days that have data, capped at 0.9
    confidence = min(0.9, len(daily_totals) / 30)

    return {
        "daily_average": round(daily_avg, 2),
        "forecast": forecast,
        "confidence": round(confidence, 2),
        "based_on_days": len(daily_totals),
        "method": "linear_average"
    }


@router.post("/alerts")
async def set_budget_alert(alert: BudgetAlert):
    """Set budget alert thresholds"""
    alert_id = _new_id()

    alert_record = {
        "id": alert_id,
        "name": alert.name,
        # Lowercased to match how providers are stored on cost entries
        "provider": alert.provider.lower() if alert.provider else None,
        "project": alert.project,
        "monthly_limit": alert.monthly_limit,
        "alert_threshold": alert.alert_threshold,
        "created_at": datetime.now().isoformat()
    }

    budget_alerts[alert_id] = alert_record

    return {
        "message": "Budget alert configured",
        "alert_id": alert_id,
        "alert": alert_record
    }


@router.get("/alerts")
async def get_budget_alerts():
    """Get all budget alerts with current status"""
    today = date.today()

    alerts_with_status = []
    for alert in budget_alerts.values():
        limit = alert["monthly_limit"]
        current_spend = _alert_month_spend(alert, today)
        # A non-positive limit would divide by zero or flip the sign; report 0%.
        percent_used = (current_spend / limit * 100) if limit > 0 else 0

        if percent_used >= 100:
            status = "exceeded"
        elif percent_used >= alert["alert_threshold"] * 100:
            status = "warning"
        else:
            status = "ok"

        alerts_with_status.append({
            **alert,
            "current_spend": round(current_spend, 2),
            "percent_used": round(percent_used, 1),
            "remaining": round(max(0, limit - current_spend), 2),
            "status": status
        })

    return {"alerts": alerts_with_status}


@router.delete("/alerts/{alert_id}")
async def delete_budget_alert(alert_id: str):
    """Delete a budget alert"""
    if alert_id not in budget_alerts:
        raise HTTPException(status_code=404, detail="Alert not found")

    del budget_alerts[alert_id]
    return {"message": "Alert deleted", "alert_id": alert_id}


def check_budget_alerts(provider: str, project: str) -> list:
    """Check if any budget alerts are triggered

    Returns one dict per alert whose scope covers the given provider/project
    and whose current-month spend has reached `monthly_limit * alert_threshold`.
    """
    today = date.today()
    provider = provider.lower()
    # Entries with a falsy project are stored under "default"; mirror that here.
    project = project or "default"

    triggered = []
    for alert in budget_alerts.values():
        # Check if alert applies (an unset scope field matches everything)
        if alert["provider"] and alert["provider"] != provider:
            continue
        if alert["project"] and alert["project"] != project:
            continue

        current_spend = _alert_month_spend(alert, today)
        threshold_amount = alert["monthly_limit"] * alert["alert_threshold"]

        if current_spend >= threshold_amount:
            triggered.append({
                "alert_id": alert["id"],
                "alert_name": alert["name"],
                "current_spend": round(current_spend, 2),
                "limit": alert["monthly_limit"],
                "severity": "exceeded" if current_spend >= alert["monthly_limit"] else "warning"
            })

    return triggered


@router.post("/estimate")
async def estimate_cost(usage: TokenUsageEstimate):
    """Estimate cost for given token usage"""
    provider = usage.provider.lower()
    model = usage.model.lower()

    if provider not in PROVIDER_PRICING:
        raise HTTPException(status_code=400, detail=f"Unknown provider: {provider}")

    provider_models = PROVIDER_PRICING[provider]

    # An exact name always wins; otherwise "gpt-4" would resolve to "gpt-4o"
    # simply because that key is listed first.
    matched_model = next((m for m in provider_models if m.lower() == model), None)
    # Fall back to a substring (fuzzy) match. An empty query is skipped
    # because "" is a substring of every model name.
    if matched_model is None and model:
        matched_model = next((m for m in provider_models if model in m.lower()), None)

    if not matched_model:
        return {
            "error": f"Model '{model}' not found for provider '{provider}'",
            "available_models": list(provider_models.keys())
        }

    pricing = provider_models[matched_model]

    # Calculate cost (pricing is per 1M tokens)
    # NOTE(review): models priced per minute / per image are run through the
    # same formula; the returned "pricing" dict exposes the unit to the caller.
    input_cost = (usage.input_tokens / 1_000_000) * pricing["input"]
    output_cost = (usage.output_tokens / 1_000_000) * pricing["output"]
    total_cost = input_cost + output_cost

    return {
        "provider": provider,
        "model": matched_model,
        "input_tokens": usage.input_tokens,
        "output_tokens": usage.output_tokens,
        "input_cost": round(input_cost, 6),
        "output_cost": round(output_cost, 6),
        "total_cost": round(total_cost, 6),
        "pricing": pricing
    }


@router.get("/providers")
async def list_providers():
    """List supported providers with current pricing"""
    providers = [
        {
            "name": provider,
            "models": [
                {
                    "name": model,
                    "input_price": pricing["input"],
                    "output_price": pricing["output"],
                    "unit": pricing["unit"]
                }
                for model, pricing in models.items()
            ]
        }
        for provider, models in PROVIDER_PRICING.items()
    ]

    return {"providers": providers}


@router.get("/compare-providers")
async def compare_providers(
    input_tokens: int = Query(1000000),
    output_tokens: int = Query(500000)
):
    """Compare costs across providers for the same usage"""
    comparisons = []

    for provider, models in PROVIDER_PRICING.items():
        for model, pricing in models.items():
            if pricing["unit"] != "1M tokens":
                continue  # Skip non-token based pricing

            input_cost = (input_tokens / 1_000_000) * pricing["input"]
            output_cost = (output_tokens / 1_000_000) * pricing["output"]

            comparisons.append({
                "provider": provider,
                "model": model,
                "input_cost": round(input_cost, 4),
                "output_cost": round(output_cost, 4),
                "total_cost": round(input_cost + output_cost, 4)
            })

    # Sort by total cost, cheapest first
    comparisons.sort(key=lambda x: x["total_cost"])

    cheapest = comparisons[0] if comparisons else None
    most_expensive = comparisons[-1] if comparisons else None

    if cheapest and most_expensive:
        savings_potential = round(most_expensive["total_cost"] - cheapest["total_cost"], 4)
    else:
        savings_potential = 0

    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "comparisons": comparisons,
        "cheapest": cheapest,
        "most_expensive": most_expensive,
        "savings_potential": savings_potential
    }


@router.get("/stats")
async def get_cost_stats():
    """Get overall cost statistics"""
    if not cost_entries:
        return {
            "message": "No cost data available",
            "total_entries": 0
        }

    today = date.today()
    this_month_start = _month_start(today)
    last_month_start = _month_start(today, -1)
    next_month_start = _month_start(today, 1)

    # Parse every entry date once; reused for the month buckets and date range
    dated = [(date.fromisoformat(e["entry_date"]), e["amount"]) for e in cost_entries]

    # This month (bounded above so future-dated entries are not counted)
    this_month_total = sum(
        amount for d, amount in dated if this_month_start <= d < next_month_start
    )

    # Last month
    last_month_total = sum(
        amount for d, amount in dated if last_month_start <= d < this_month_start
    )

    # Calculate change; with no spend last month, report 100% if there is any
    # spend this month and 0% otherwise
    if last_month_total > 0:
        month_change = ((this_month_total - last_month_total) / last_month_total) * 100
    else:
        month_change = 100 if this_month_total > 0 else 0

    # All time stats
    all_time_total = sum(amount for _, amount in dated)
    unique_providers = len({e["provider"] for e in cost_entries})
    unique_projects = len({e["project"] for e in cost_entries})

    # Date range (cost_entries is non-empty here, so min/max are safe)
    dates = [d for d, _ in dated]

    return {
        "this_month_total": round(this_month_total, 2),
        "last_month_total": round(last_month_total, 2),
        "month_over_month_change": round(month_change, 1),
        "all_time_total": round(all_time_total, 2),
        "total_entries": len(cost_entries),
        "unique_providers": unique_providers,
        "unique_projects": unique_projects,
        "date_range": {
            "earliest": min(dates).isoformat(),
            "latest": max(dates).isoformat()
        }
    }