608 lines
20 KiB
Python
608 lines
20 KiB
Python
"""Vendor Cost Tracker Router - Track and analyze AI API spending"""
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
from pydantic import BaseModel
|
|
from typing import Optional
|
|
from datetime import datetime, date, timedelta
|
|
import uuid
|
|
from collections import defaultdict
|
|
|
|
# Router object that every endpoint in this module registers itself on.
router = APIRouter()

# In-memory storage for cost entries and alerts. These are plain module-level
# containers, so their contents live only as long as the process does.
# cost_entries holds the record dicts appended by the logging endpoints.
cost_entries: list = []
# budget_alerts maps an alert id to its alert record dict.
budget_alerts: dict = {}
|
|
|
|
# Pricing table keyed by provider name, then by model name. Every entry carries
# its own "unit": most models are priced per 1M tokens, but a few are priced
# per minute or per image, so check "unit" before applying token-based math.
def _per_million_tokens(input_price: float, output_price: float) -> dict:
    """Build the pricing entry for a model billed per one million tokens."""
    return {"input": input_price, "output": output_price, "unit": "1M tokens"}


PROVIDER_PRICING = {
    "openai": {
        "gpt-4o": _per_million_tokens(2.50, 10.00),
        "gpt-4o-mini": _per_million_tokens(0.15, 0.60),
        "gpt-4-turbo": _per_million_tokens(10.00, 30.00),
        "gpt-4": _per_million_tokens(30.00, 60.00),
        "gpt-3.5-turbo": _per_million_tokens(0.50, 1.50),
        "text-embedding-3-small": _per_million_tokens(0.02, 0.0),
        "text-embedding-3-large": _per_million_tokens(0.13, 0.0),
        # The next two are not token-priced; their unit is spelled out.
        "whisper": {"input": 0.006, "output": 0.0, "unit": "per minute"},
        "dall-e-3": {"input": 0.04, "output": 0.0, "unit": "per image (1024x1024)"},
    },
    "anthropic": {
        "claude-opus-4": _per_million_tokens(15.00, 75.00),
        "claude-sonnet-4": _per_million_tokens(3.00, 15.00),
        "claude-3.5-sonnet": _per_million_tokens(3.00, 15.00),
        "claude-3.5-haiku": _per_million_tokens(0.80, 4.00),
        "claude-3-opus": _per_million_tokens(15.00, 75.00),
        "claude-3-sonnet": _per_million_tokens(3.00, 15.00),
        "claude-3-haiku": _per_million_tokens(0.25, 1.25),
    },
    "google": {
        "gemini-2.0-flash": _per_million_tokens(0.10, 0.40),
        "gemini-1.5-pro": _per_million_tokens(1.25, 5.00),
        "gemini-1.5-flash": _per_million_tokens(0.075, 0.30),
        "gemini-1.0-pro": _per_million_tokens(0.50, 1.50),
    },
    "aws": {
        "bedrock-claude-3-opus": _per_million_tokens(15.00, 75.00),
        "bedrock-claude-3-sonnet": _per_million_tokens(3.00, 15.00),
        "bedrock-claude-3-haiku": _per_million_tokens(0.25, 1.25),
        "bedrock-titan-text": _per_million_tokens(0.80, 1.00),
        "bedrock-titan-embeddings": _per_million_tokens(0.10, 0.0),
    },
    "azure": {
        "azure-gpt-4o": _per_million_tokens(2.50, 10.00),
        "azure-gpt-4-turbo": _per_million_tokens(10.00, 30.00),
        "azure-gpt-4": _per_million_tokens(30.00, 60.00),
        "azure-gpt-35-turbo": _per_million_tokens(0.50, 1.50),
    },
    "cohere": {
        "command-r-plus": _per_million_tokens(2.50, 10.00),
        "command-r": _per_million_tokens(0.15, 0.60),
        "embed-english-v3.0": _per_million_tokens(0.10, 0.0),
    },
    "mistral": {
        "mistral-large": _per_million_tokens(2.00, 6.00),
        "mistral-small": _per_million_tokens(0.20, 0.60),
        "mistral-embed": _per_million_tokens(0.10, 0.0),
    },
}
|
|
|
|
|
|
class CostEntry(BaseModel):
    # Request payload describing one cost item to be logged.
    # (Documented with comments rather than a docstring so the generated
    # schema description is left untouched.)
    provider: str  # lower-cased by the logging endpoints before it is stored
    model: Optional[str] = None  # entries without a model are left out of the by-model summary
    amount: float  # summed as-is by the summary, alert and stats code
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    requests: Optional[int] = None
    project: Optional[str] = "default"  # a null/empty value is stored as "default"
    description: Optional[str] = None
    entry_date: date  # persisted as an ISO-format string
|
|
|
|
|
|
class BudgetAlert(BaseModel):
    # Request payload for configuring a monthly budget alert.
    # (Documented with comments rather than a docstring so the generated
    # schema description is left untouched.)
    name: str
    provider: Optional[str] = None  # None: the alert is not restricted to one provider
    project: Optional[str] = None  # None: the alert is not restricted to one project
    monthly_limit: float  # spend ceiling the current month's total is compared against
    alert_threshold: float = 0.8  # fraction of monthly_limit at which a warning starts
|
|
|
|
|
|
class CostSummary(BaseModel):
    # Field names mirror the keys of the dict returned by the /summary endpoint.
    # NOTE(review): no endpoint visible in this file references this model
    # (e.g. as a response_model) - confirm whether it is used elsewhere.
    total: float
    by_provider: dict
    by_project: dict
    by_model: dict
    daily_breakdown: list
    period_start: str
    period_end: str
    entry_count: int
|
|
|
|
|
|
class TokenUsageEstimate(BaseModel):
    # Request payload for the /estimate endpoint.
    provider: str  # looked up case-insensitively in PROVIDER_PRICING
    model: str  # lower-cased and matched against that provider's model names
    input_tokens: int
    output_tokens: int
|
|
|
|
|
|
@router.post("/log")
async def log_cost(entry: CostEntry):
    """Log a single cost entry and report any budget alerts it trips.

    The provider is lower-cased and a missing/empty project becomes
    "default". The same normalized values are used both for the stored
    record and for the budget-alert check, so an entry submitted with a null
    project is still evaluated against alerts scoped to "default".

    Returns a dict with a confirmation message, the generated ``entry_id``,
    the stored ``entry`` record and the list of ``alerts_triggered``.
    """
    # Short id: the first 8 characters of a random UUID string.
    entry_id = str(uuid.uuid4())[:8]

    # Normalize once so storage and alert evaluation cannot disagree.
    provider = entry.provider.lower()
    project = entry.project or "default"

    cost_record = {
        "id": entry_id,
        "provider": provider,
        "model": entry.model,
        "amount": entry.amount,
        "input_tokens": entry.input_tokens,
        "output_tokens": entry.output_tokens,
        "requests": entry.requests,
        "project": project,
        "description": entry.description,
        "entry_date": entry.entry_date.isoformat(),
        "created_at": datetime.now().isoformat(),
    }
    cost_entries.append(cost_record)

    # Check budget alerts after the append so the new entry is counted.
    triggered_alerts = check_budget_alerts(provider, project)

    return {
        "message": "Cost logged successfully",
        "entry_id": entry_id,
        "entry": cost_record,
        "alerts_triggered": triggered_alerts,
    }
|
|
|
|
|
|
@router.post("/log-batch")
async def log_costs_batch(entries: list[CostEntry]):
    """Log multiple cost entries at once.

    Each entry is normalized and stored exactly like a single logged entry
    (lower-cased provider, missing project stored as "default"). Once all
    entries are stored, budget alerts are evaluated once per distinct
    (provider, project) pair touched by the batch, so batch logging no longer
    bypasses alerting.

    Returns a dict with a summary ``message``, the stored ``entries`` and the
    de-duplicated ``alerts_triggered`` list.
    """
    results = []
    touched_scopes = set()

    for entry in entries:
        provider = entry.provider.lower()
        project = entry.project or "default"
        cost_record = {
            "id": str(uuid.uuid4())[:8],
            "provider": provider,
            "model": entry.model,
            "amount": entry.amount,
            "input_tokens": entry.input_tokens,
            "output_tokens": entry.output_tokens,
            "requests": entry.requests,
            "project": project,
            "description": entry.description,
            "entry_date": entry.entry_date.isoformat(),
            "created_at": datetime.now().isoformat(),
        }
        cost_entries.append(cost_record)
        results.append(cost_record)
        touched_scopes.add((provider, project))

    # Evaluate alerts only after the whole batch is stored, keyed by alert id
    # so an alert matching several scopes is reported once.
    triggered_by_id = {}
    for provider, project in sorted(touched_scopes):
        for alert in check_budget_alerts(provider, project):
            triggered_by_id[alert["alert_id"]] = alert

    return {
        "message": f"Logged {len(results)} cost entries",
        "entries": results,
        "alerts_triggered": list(triggered_by_id.values()),
    }
|
|
|
|
|
|
@router.get("/summary")
async def get_cost_summary(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    provider: Optional[str] = None,
    project: Optional[str] = None
):
    """Summarize logged costs over an inclusive date range.

    When omitted, ``start_date`` defaults to the first day of the current
    month and ``end_date`` to today. ``provider`` is compared
    case-insensitively; ``project`` must match exactly.

    Returns totals overall, per provider, per project and per
    "provider/model", each ordered from highest to lowest spend and rounded to
    two decimals, plus a date-ordered daily breakdown.
    """
    if start_date is None:
        start_date = date.today().replace(day=1)
    if end_date is None:
        end_date = date.today()

    wanted_provider = provider.lower() if provider else None

    def _selected(record):
        # Keep a record only if it is in range and passes both optional filters.
        if not start_date <= date.fromisoformat(record["entry_date"]) <= end_date:
            return False
        if wanted_provider and record["provider"] != wanted_provider:
            return False
        return not (project and record["project"] != project)

    def _ranked(totals):
        # Highest spend first; values rounded for presentation.
        ordered = sorted(totals.items(), key=lambda item: -item[1])
        return {key: round(value, 2) for key, value in ordered}

    filtered = [record for record in cost_entries if _selected(record)]

    provider_totals = defaultdict(float)
    project_totals = defaultdict(float)
    model_totals = defaultdict(float)
    per_day = defaultdict(float)

    for record in filtered:
        amount = record["amount"]
        provider_totals[record["provider"]] += amount
        project_totals[record["project"]] += amount
        if record["model"]:
            model_totals[f"{record['provider']}/{record['model']}"] += amount
        per_day[record["entry_date"]] += amount

    grand_total = sum(record["amount"] for record in filtered)

    # ISO date strings sort chronologically.
    daily_breakdown = []
    for day, amount in sorted(per_day.items()):
        daily_breakdown.append({"date": day, "amount": round(amount, 2)})

    return {
        "total": round(grand_total, 2),
        "by_provider": _ranked(provider_totals),
        "by_project": _ranked(project_totals),
        "by_model": _ranked(model_totals),
        "daily_breakdown": daily_breakdown,
        "period_start": start_date.isoformat(),
        "period_end": end_date.isoformat(),
        "entry_count": len(filtered),
    }
|
|
|
|
|
|
@router.get("/entries")
async def get_cost_entries(
    limit: int = Query(100, ge=0, le=1000),
    offset: int = Query(0, ge=0),
    provider: Optional[str] = None,
    project: Optional[str] = None
):
    """Get individual cost entries with pagination.

    ``provider`` is matched case-insensitively and ``project`` exactly.
    Results are ordered by entry date, newest first. ``limit`` and ``offset``
    are rejected when negative: a negative value would otherwise turn the
    page slice into from-the-end indexing and silently return the wrong rows.

    Returns the requested page under ``entries`` along with the ``total``
    number of matching entries and the ``limit``/``offset`` that were applied.
    """
    filtered = cost_entries

    if provider:
        wanted_provider = provider.lower()
        filtered = [e for e in filtered if e["provider"] == wanted_provider]
    if project:
        filtered = [e for e in filtered if e["project"] == project]

    # Sort by date descending (ISO date strings sort chronologically).
    filtered = sorted(filtered, key=lambda x: x["entry_date"], reverse=True)

    return {
        "entries": filtered[offset:offset + limit],
        "total": len(filtered),
        "limit": limit,
        "offset": offset,
    }
|
|
|
|
|
|
@router.delete("/entries/{entry_id}")
async def delete_cost_entry(entry_id: str):
    """Delete the cost entry (or entries) stored under ``entry_id``.

    Raises:
        HTTPException: 404 when no stored entry has that id.
    """
    remaining = [e for e in cost_entries if e["id"] != entry_id]

    if len(remaining) == len(cost_entries):
        raise HTTPException(status_code=404, detail="Entry not found")

    # Mutate the shared list in place instead of rebinding the module global,
    # so any other holder of the list object keeps seeing current data.
    cost_entries[:] = remaining

    return {"message": "Entry deleted", "entry_id": entry_id}
|
|
|
|
|
|
@router.get("/forecast")
async def forecast_costs(
    months: int = Query(3, ge=1, le=12),
    provider: Optional[str] = None,
    project: Optional[str] = None
):
    """Forecast future monthly costs from the last 30 days of logged spend.

    The daily average is taken over the days that have at least one matching
    entry, then multiplied by the number of days in each forecast month. The
    forecast covers the ``months`` calendar months following the current one,
    so each month appears exactly once. The bounds are a flat +/-20% band
    around the prediction.

    Returns a dict with ``daily_average``, the ``forecast`` list,
    ``confidence`` (days with data / 30, capped at 0.9), ``based_on_days`` and
    ``method``. When there is too little data, returns a message with an
    empty forecast and zero confidence instead.
    """
    # NOTE: this guard counts all stored entries (before any filtering),
    # not distinct days, despite the wording of the message.
    if len(cost_entries) < 7:
        return {
            "message": "Need at least 7 days of data for forecasting",
            "forecast": [],
            "confidence": 0.0,
        }

    # Get last 30 days of data
    today = date.today()
    thirty_days_ago = today - timedelta(days=30)
    wanted_provider = provider.lower() if provider else None

    recent = []
    for entry in cost_entries:
        if date.fromisoformat(entry["entry_date"]) < thirty_days_ago:
            continue
        if wanted_provider and entry["provider"] != wanted_provider:
            continue
        if project and entry["project"] != project:
            continue
        recent.append(entry)

    if not recent:
        return {
            "message": "No recent data for forecasting",
            "forecast": [],
            "confidence": 0.0,
        }

    # Calculate daily average over the days that actually have data.
    daily_totals = defaultdict(float)
    for entry in recent:
        daily_totals[entry["entry_date"]] += entry["amount"]

    daily_avg = sum(daily_totals.values()) / max(len(daily_totals), 1)

    forecast = []
    for step in range(1, months + 1):
        # Step whole calendar months. Adding 30-day blocks to today can skip
        # or repeat a month label (e.g. Jan 31 + 30 days lands in March).
        month_index = (today.month - 1) + step
        year = today.year + month_index // 12
        month = month_index % 12 + 1

        # Days in that month = distance between consecutive first-of-months.
        month_start = date(year, month, 1)
        following_start = date(year + month // 12, month % 12 + 1, 1)
        days_in_month = (following_start - month_start).days

        base_forecast = daily_avg * days_in_month

        forecast.append({
            "month": month_start.strftime("%Y-%m"),
            "predicted_cost": round(base_forecast, 2),
            "lower_bound": round(base_forecast * 0.8, 2),
            "upper_bound": round(base_forecast * 1.2, 2),
        })

    # Confidence based on data points
    confidence = min(0.9, len(daily_totals) / 30)

    return {
        "daily_average": round(daily_avg, 2),
        "forecast": forecast,
        "confidence": round(confidence, 2),
        "based_on_days": len(daily_totals),
        "method": "linear_average",
    }
|
|
|
|
|
|
@router.post("/alerts")
async def set_budget_alert(alert: BudgetAlert):
    """Store a new budget alert and return it with its generated id.

    The provider scope, when given, is lower-cased. A missing provider or
    project is stored as ``None``/as given, meaning the alert is not
    restricted on that dimension.
    """
    # Short id: the first 8 characters of a random UUID string.
    new_id = str(uuid.uuid4())[:8]

    scoped_provider = None
    if alert.provider:
        scoped_provider = alert.provider.lower()

    record = {
        "id": new_id,
        "name": alert.name,
        "provider": scoped_provider,
        "project": alert.project,
        "monthly_limit": alert.monthly_limit,
        "alert_threshold": alert.alert_threshold,
        "created_at": datetime.now().isoformat(),
    }
    budget_alerts[new_id] = record

    response = {"message": "Budget alert configured"}
    response["alert_id"] = new_id
    response["alert"] = record
    return response
|
|
|
|
|
|
@router.get("/alerts")
async def get_budget_alerts():
    """Get all budget alerts with their current-month status.

    For every configured alert, the spend of the current calendar month within
    the alert's provider/project scope is compared against its monthly limit.
    Entries dated in a later month are not counted toward the current month.

    Each returned alert carries the stored fields plus ``current_spend``,
    ``percent_used``, ``remaining`` and ``status`` ("ok", "warning" or
    "exceeded").
    """
    today = date.today()
    current_month = (today.year, today.month)

    # Select this month's entries once instead of re-parsing dates per alert.
    month_entries = []
    for entry in cost_entries:
        entry_day = date.fromisoformat(entry["entry_date"])
        if (entry_day.year, entry_day.month) == current_month:
            month_entries.append(entry)

    alerts_with_status = []
    for alert in budget_alerts.values():
        # Narrow to this alert's scope; a falsy scope means "no restriction".
        scoped = month_entries
        if alert["provider"]:
            scoped = [e for e in scoped if e["provider"] == alert["provider"]]
        if alert["project"]:
            scoped = [e for e in scoped if e["project"] == alert["project"]]

        current_spend = sum(e["amount"] for e in scoped)
        limit = alert["monthly_limit"]
        # A non-positive limit is reported as 0% used rather than dividing by it.
        percent_used = (current_spend / limit * 100) if limit > 0 else 0

        status = "ok"
        if percent_used >= 100:
            status = "exceeded"
        elif percent_used >= alert["alert_threshold"] * 100:
            status = "warning"

        alerts_with_status.append({
            **alert,
            "current_spend": round(current_spend, 2),
            "percent_used": round(percent_used, 1),
            "remaining": round(max(0, limit - current_spend), 2),
            "status": status,
        })

    return {"alerts": alerts_with_status}
|
|
|
|
|
|
@router.delete("/alerts/{alert_id}")
async def delete_budget_alert(alert_id: str):
    """Remove the budget alert stored under ``alert_id``.

    Raises:
        HTTPException: 404 when no alert with that id exists.
    """
    try:
        del budget_alerts[alert_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Alert not found")

    confirmation = {"message": "Alert deleted"}
    confirmation["alert_id"] = alert_id
    return confirmation
|
|
|
|
|
|
def check_budget_alerts(provider: str, project: str) -> list:
    """Return the budget alerts tripped for the given provider/project scope.

    An alert applies when each of its provider/project restrictions is either
    unset or equal to the given value. ``provider`` is compared
    case-insensitively, and a missing ``project`` is treated as "default",
    matching how entries are stored. Spend is the sum of the current calendar
    month's entries within the alert's own scope; entries dated in a later
    month are not counted.

    Alerts whose monthly limit is not positive are skipped: they cannot be
    meaningfully exceeded, and the alert listing reports them as 0% used.

    Returns:
        A list of dicts with ``alert_id``, ``alert_name``, ``current_spend``,
        ``limit`` and ``severity`` ("warning" or "exceeded").
    """
    wanted_provider = (provider or "").lower()
    wanted_project = project or "default"

    today = date.today()
    current_month = (today.year, today.month)

    # Parse dates once up front rather than once per alert.
    month_entries = []
    for entry in cost_entries:
        entry_day = date.fromisoformat(entry["entry_date"])
        if (entry_day.year, entry_day.month) == current_month:
            month_entries.append(entry)

    triggered = []
    for alert in budget_alerts.values():
        # Check if alert applies
        if alert["provider"] and alert["provider"] != wanted_provider:
            continue
        if alert["project"] and alert["project"] != wanted_project:
            continue

        limit = alert["monthly_limit"]
        if limit <= 0:
            # Without this guard the threshold would be <= 0 and every log
            # call would report the alert as "exceeded".
            continue

        # Calculate current spend within the alert's own scope
        scoped = month_entries
        if alert["provider"]:
            scoped = [e for e in scoped if e["provider"] == alert["provider"]]
        if alert["project"]:
            scoped = [e for e in scoped if e["project"] == alert["project"]]

        current_spend = sum(e["amount"] for e in scoped)
        threshold_amount = limit * alert["alert_threshold"]

        if current_spend >= threshold_amount:
            triggered.append({
                "alert_id": alert["id"],
                "alert_name": alert["name"],
                "current_spend": round(current_spend, 2),
                "limit": limit,
                "severity": "exceeded" if current_spend >= limit else "warning",
            })

    return triggered
|
|
|
|
|
|
@router.post("/estimate")
async def estimate_cost(usage: TokenUsageEstimate):
    """Estimate the cost of a given token usage for one provider/model.

    The model name is matched case-insensitively. An exact name always wins;
    only when there is no exact match does a substring match apply (first hit
    in table order). Without that priority "gpt-4" would resolve to "gpt-4o"
    and "command-r" to "command-r-plus". An empty model name matches nothing.

    Returns the matched model, the token counts, input/output/total cost
    (rounded to 6 decimals) and the pricing entry used. When the model is not
    found, returns an ``error`` message with the ``available_models``.

    Raises:
        HTTPException: 400 for an unknown provider, or when the matched model
            is not priced per token (e.g. per minute or per image), since a
            token-based estimate would be meaningless for it.
    """
    provider = usage.provider.lower()
    model = usage.model.lower()

    if provider not in PROVIDER_PRICING:
        raise HTTPException(status_code=400, detail=f"Unknown provider: {provider}")

    provider_models = PROVIDER_PRICING[provider]

    # Exact match first, substring match only as a fallback.
    matched_model = None
    if model:
        by_lowered_name = {name.lower(): name for name in provider_models}
        matched_model = by_lowered_name.get(model)
        if matched_model is None:
            matched_model = next(
                (name for lowered, name in by_lowered_name.items() if model in lowered),
                None,
            )

    if not matched_model:
        return {
            "error": f"Model '{model}' not found for provider '{provider}'",
            "available_models": list(provider_models.keys()),
        }

    pricing = provider_models[matched_model]

    if pricing["unit"] != "1M tokens":
        raise HTTPException(
            status_code=400,
            detail=(
                f"Model '{matched_model}' is priced {pricing['unit']}; "
                "token-based estimation is not supported"
            ),
        )

    # Calculate cost (pricing is per 1M tokens)
    input_cost = (usage.input_tokens / 1_000_000) * pricing["input"]
    output_cost = (usage.output_tokens / 1_000_000) * pricing["output"]
    total_cost = input_cost + output_cost

    return {
        "provider": provider,
        "model": matched_model,
        "input_tokens": usage.input_tokens,
        "output_tokens": usage.output_tokens,
        "input_cost": round(input_cost, 6),
        "output_cost": round(output_cost, 6),
        "total_cost": round(total_cost, 6),
        "pricing": pricing,
    }
|
|
|
|
|
|
@router.get("/providers")
async def list_providers():
    """Return every provider in the pricing table with its models and prices.

    Each provider is reported as ``{"name", "models"}``, where every model
    lists its ``name``, ``input_price``, ``output_price`` and pricing ``unit``.
    """
    catalogue = [
        {
            "name": provider_name,
            "models": [
                {
                    "name": model_name,
                    "input_price": price["input"],
                    "output_price": price["output"],
                    "unit": price["unit"],
                }
                for model_name, price in model_table.items()
            ],
        }
        for provider_name, model_table in PROVIDER_PRICING.items()
    ]
    return {"providers": catalogue}
|
|
|
|
|
|
@router.get("/compare-providers")
async def compare_providers(
    input_tokens: int = Query(1000000, ge=0),
    output_tokens: int = Query(500000, ge=0)
):
    """Compare costs across providers for the same token usage.

    Only models priced per 1M tokens take part. Token counts must be
    non-negative: negative counts would give negative costs and invert the
    cheapest/most-expensive ranking.

    Returns the per-model ``comparisons`` ordered from cheapest to most
    expensive, the ``cheapest`` and ``most_expensive`` entries, and
    ``savings_potential`` (the difference between those two totals).
    """
    comparisons = []

    for provider, models in PROVIDER_PRICING.items():
        for model, pricing in models.items():
            if pricing["unit"] != "1M tokens":
                continue  # Skip non-token based pricing

            input_cost = (input_tokens / 1_000_000) * pricing["input"]
            output_cost = (output_tokens / 1_000_000) * pricing["output"]

            comparisons.append({
                "provider": provider,
                "model": model,
                "input_cost": round(input_cost, 4),
                "output_cost": round(output_cost, 4),
                "total_cost": round(input_cost + output_cost, 4),
            })

    # Sort by total cost (stable, so ties keep pricing-table order)
    comparisons.sort(key=lambda x: x["total_cost"])

    cheapest = comparisons[0] if comparisons else None
    most_expensive = comparisons[-1] if comparisons else None

    if cheapest and most_expensive:
        savings_potential = round(most_expensive["total_cost"] - cheapest["total_cost"], 4)
    else:
        savings_potential = 0

    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "comparisons": comparisons,
        "cheapest": cheapest,
        "most_expensive": most_expensive,
        "savings_potential": savings_potential,
    }
|
|
|
|
|
|
@router.get("/stats")
async def get_cost_stats():
    """Get overall cost statistics.

    Reports the current and previous calendar month totals, the
    month-over-month change in percent, the all-time total, entry and
    distinct provider/project counts, and the earliest/latest entry dates.
    "This month" covers only the current calendar month, so entries dated in
    a later month are not counted toward it.

    When nothing has been logged, returns a message with ``total_entries``
    set to 0.
    """
    if not cost_entries:
        return {
            "message": "No cost data available",
            "total_entries": 0,
        }

    today = date.today()
    this_month_start = date(today.year, today.month, 1)
    if today.month > 1:
        last_month_start = date(today.year, today.month - 1, 1)
    else:
        # January: the previous month is December of the previous year.
        last_month_start = date(today.year - 1, 12, 1)

    # Parse every entry date once and reuse it for all windows below.
    dated_amounts = [
        (date.fromisoformat(e["entry_date"]), e["amount"]) for e in cost_entries
    ]

    # This month: same calendar year and month as today.
    this_month_total = sum(
        amount for day, amount in dated_amounts
        if (day.year, day.month) == (today.year, today.month)
    )

    # Last month
    last_month_total = sum(
        amount for day, amount in dated_amounts
        if last_month_start <= day < this_month_start
    )

    # Calculate change
    if last_month_total > 0:
        month_change = ((this_month_total - last_month_total) / last_month_total) * 100
    else:
        # No baseline to compare against: report 100% for any spend, else 0.
        month_change = 100 if this_month_total > 0 else 0

    # All time stats
    all_time_total = sum(amount for _, amount in dated_amounts)
    unique_providers = len({e["provider"] for e in cost_entries})
    unique_projects = len({e["project"] for e in cost_entries})

    # Date range; the empty case already returned above, so min/max are safe.
    dates = [day for day, _ in dated_amounts]

    return {
        "this_month_total": round(this_month_total, 2),
        "last_month_total": round(last_month_total, 2),
        "month_over_month_change": round(month_change, 1),
        "all_time_total": round(all_time_total, 2),
        "total_entries": len(cost_entries),
        "unique_providers": unique_providers,
        "unique_projects": unique_projects,
        "date_range": {
            "earliest": min(dates).isoformat(),
            "latest": max(dates).isoformat(),
        },
    }
|