236 lines
6.9 KiB
Python
236 lines
6.9 KiB
Python
"""Inference Estimator Router"""
|
|
import json
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
|
|
|
|
# Router object that the endpoint decorators in this module register on.
router = APIRouter()

# Path to pricing config: config/pricing.json under the parent of the
# directory that contains this module.
CONFIG_PATH = Path(__file__).parent.parent / "config" / "pricing.json"
|
|
|
|
|
|
def load_pricing() -> dict:
    """Load the pricing config and return it with user overrides applied.

    Reads ``CONFIG_PATH`` as JSON and layers each entry of the
    ``user_overrides`` section over the ``models`` section: an override for
    an existing model updates that model's fields, an override for an
    unknown name is added as a new model.

    Returns:
        A dict with the keys ``models`` (merged pricing keyed by model name),
        ``last_updated`` (default ``"unknown"``), ``sources`` (default ``{}``)
        and ``currency`` (default ``"USD"``).

    Raises:
        HTTPException: 500 if the config file does not exist.
    """
    # Open directly instead of checking exists() first, so the file cannot
    # disappear between the check and the read.
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError as exc:
        raise HTTPException(
            status_code=500, detail="Pricing config not found"
        ) from exc

    # ``or {}`` also tolerates sections that are present but null.
    models = config.get("models") or {}
    overrides = config.get("user_overrides") or {}

    # Merge user overrides with base pricing
    for model_name, override_data in overrides.items():
        if model_name in models:
            models[model_name].update(override_data)
        else:
            models[model_name] = override_data

    return {
        "models": models,
        "last_updated": config.get("last_updated", "unknown"),
        "sources": config.get("sources", {}),
        "currency": config.get("currency", "USD"),
    }
|
|
|
|
|
|
def save_pricing(config: dict) -> None:
    """Write the pricing config to ``CONFIG_PATH`` as indented JSON.

    The data is first written to a sibling ``<name>.tmp`` file and then
    moved over the real config. Opening the real file with ``"w"`` would
    truncate it before serialisation starts, so a failure in ``json.dump``
    would leave an empty or partial config behind.

    Args:
        config: The full config dict to persist (not the merged view
            returned by ``load_pricing``).

    Raises:
        TypeError: If ``config`` contains values ``json`` cannot serialise.
        OSError: If the temp file cannot be written or moved into place.
    """
    tmp_path = CONFIG_PATH.with_name(CONFIG_PATH.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)
    except Exception:
        # Do not leave a half-written temp file next to the config.
        tmp_path.unlink(missing_ok=True)
        raise
    # Path.replace overwrites the destination in a single rename step.
    tmp_path.replace(CONFIG_PATH)
|
|
|
|
|
|
class EstimateRequest(BaseModel):
    """Request body for a single-model cost estimate.

    All numeric fields must be non-negative; negative values would
    otherwise flow straight into the cost arithmetic and produce negative
    costs.
    """

    # Name of the model; looked up as a key in the pricing config.
    model: str
    # Input tokens sent with each request.
    input_tokens_per_request: int = Field(500, ge=0)
    # Output tokens produced by each request.
    output_tokens_per_request: int = Field(500, ge=0)
    # Number of requests made per day.
    requests_per_day: int = Field(1000, ge=0)
    # Number of days the daily cost is multiplied by for the monthly figure.
    days_per_month: int = Field(30, ge=0)
|
|
|
|
|
|
class EstimateResponse(BaseModel):
    """Cost estimate for one model, as returned by the /calculate endpoint."""

    # Model name, echoed back from the request.
    model: str
    # Input plus output cost for one day.
    daily_cost: float
    # Daily cost multiplied by the requested days_per_month.
    monthly_cost: float
    # Monthly cost multiplied by 12.
    yearly_cost: float
    # Input tokens over one month (daily tokens * days_per_month).
    total_input_tokens: int
    # Output tokens over one month (daily tokens * days_per_month).
    total_output_tokens: int
    # Per-day input/output costs and the per-1M-token prices used, or a
    # single "error" entry when the model is unknown.
    breakdown: dict
|
|
|
|
|
|
class CompareRequest(BaseModel):
    """Request body for comparing the cost of several models.

    The usage fields have the same meaning and defaults as on
    ``EstimateRequest`` and must be non-negative, so that a comparison can
    never be built from negative token or request counts.
    """

    # Names of the models to compare; each is looked up in the pricing config.
    models: list[str]
    # Input tokens sent with each request.
    input_tokens_per_request: int = Field(500, ge=0)
    # Output tokens produced by each request.
    output_tokens_per_request: int = Field(500, ge=0)
    # Number of requests made per day.
    requests_per_day: int = Field(1000, ge=0)
    # Number of days the daily cost is multiplied by for the monthly figure.
    days_per_month: int = Field(30, ge=0)
|
|
|
|
|
|
class PriceOverride(BaseModel):
    """Request body for setting a price override or adding a custom model.

    Prices must be non-negative: they are written to the config file and
    later multiplied into every estimate for the model.
    """

    # Name of the model the prices apply to.
    model: str
    # Price per 1M input tokens.
    input: float = Field(..., ge=0)
    # Price per 1M output tokens.
    output: float = Field(..., ge=0)
    # Optional free-text description; the endpoints fill in a default.
    description: Optional[str] = None
|
|
|
|
|
|
def _build_estimate(
    model_name: str,
    pricing: dict,
    input_tokens_per_request: int,
    output_tokens_per_request: int,
    requests_per_day: int,
    days_per_month: int,
) -> EstimateResponse:
    """Compute the cost estimate for one model from its pricing entry.

    Shared by the /calculate and /compare endpoints so both use the same
    arithmetic on an already-loaded pricing entry.

    Args:
        model_name: Name reported back in the response.
        pricing: Pricing entry for the model; must contain ``input`` and
            ``output`` prices per 1M tokens.
        input_tokens_per_request: Input tokens per request.
        output_tokens_per_request: Output tokens per request.
        requests_per_day: Requests per day.
        days_per_month: Days used to scale the daily figures to a month.

    Returns:
        The estimate, with costs rounded to two decimals.

    Raises:
        HTTPException: 500 if the pricing entry lacks ``input`` or ``output``.
    """
    missing = [key for key in ("input", "output") if key not in pricing]
    if missing:
        # A malformed config entry is a server-side problem; report it
        # clearly instead of surfacing a bare KeyError.
        raise HTTPException(
            status_code=500,
            detail=f"Pricing for {model_name} is missing: {', '.join(missing)}",
        )

    daily_input_tokens = input_tokens_per_request * requests_per_day
    daily_output_tokens = output_tokens_per_request * requests_per_day

    # Prices are expressed per 1M tokens.
    daily_input_cost = (daily_input_tokens / 1_000_000) * pricing["input"]
    daily_output_cost = (daily_output_tokens / 1_000_000) * pricing["output"]
    daily_cost = daily_input_cost + daily_output_cost

    monthly_cost = daily_cost * days_per_month
    yearly_cost = monthly_cost * 12

    return EstimateResponse(
        model=model_name,
        daily_cost=round(daily_cost, 2),
        monthly_cost=round(monthly_cost, 2),
        yearly_cost=round(yearly_cost, 2),
        total_input_tokens=daily_input_tokens * days_per_month,
        total_output_tokens=daily_output_tokens * days_per_month,
        breakdown={
            "input_cost_per_day": round(daily_input_cost, 2),
            "output_cost_per_day": round(daily_output_cost, 2),
            "input_price_per_1m": pricing["input"],
            "output_price_per_1m": pricing["output"],
        },
    )


@router.post("/calculate", response_model=EstimateResponse)
async def calculate_estimate(request: EstimateRequest):
    """Calculate cost estimate for a model.

    Returns:
        The estimate for ``request.model``. For a model that is not in the
        pricing config the response is still returned normally, but with
        all figures zero and ``breakdown`` holding a single ``error`` entry.

    Raises:
        HTTPException: 500 if the pricing config is missing or the model's
            pricing entry lacks ``input``/``output``.
    """
    models = load_pricing()["models"]

    if request.model not in models:
        return EstimateResponse(
            model=request.model,
            daily_cost=0.0,
            monthly_cost=0.0,
            yearly_cost=0.0,
            total_input_tokens=0,
            total_output_tokens=0,
            breakdown={"error": f"Unknown model: {request.model}"},
        )

    return _build_estimate(
        request.model,
        models[request.model],
        request.input_tokens_per_request,
        request.output_tokens_per_request,
        request.requests_per_day,
        request.days_per_month,
    )


@router.post("/compare")
async def compare_models(request: CompareRequest):
    """Compare costs across multiple models.

    The pricing config is loaded once and reused for every requested model.

    Returns:
        A dict with ``comparison`` (estimates sorted by ascending monthly
        cost), ``cheapest`` and ``most_expensive`` (model names, or ``None``
        when no requested model is known) and ``unknown_models`` (requested
        names that are not in the pricing config and were left out).

    Raises:
        HTTPException: 500 if the pricing config is missing or a compared
            model's pricing entry lacks ``input``/``output``.
    """
    models = load_pricing()["models"]

    results = []
    unknown_models = []
    for model in request.models:
        if model not in models:
            # Left out of the comparison, but reported back to the caller.
            unknown_models.append(model)
            continue
        results.append(
            _build_estimate(
                model,
                models[model],
                request.input_tokens_per_request,
                request.output_tokens_per_request,
                request.requests_per_day,
                request.days_per_month,
            )
        )

    results.sort(key=lambda x: x.monthly_cost)

    return {
        "comparison": results,
        "cheapest": results[0].model if results else None,
        "most_expensive": results[-1].model if results else None,
        "unknown_models": unknown_models,
    }
|
|
|
|
|
|
@router.get("/models")
async def list_models():
    """Return every model in the merged pricing data, plus config metadata.

    Each model is emitted as its pricing entry with a leading ``name`` key
    holding the model's key in the pricing data.
    """
    pricing_data = load_pricing()

    model_entries = []
    for name, data in pricing_data["models"].items():
        # "name" goes first; keys from the pricing entry follow and win on
        # a clash, exactly as with {"name": name, **data}.
        model_entries.append({"name": name} | data)

    response = {
        "last_updated": pricing_data["last_updated"],
        "currency": pricing_data["currency"],
        "sources": pricing_data["sources"],
    }
    response["models"] = model_entries
    return response
|
|
|
|
|
|
@router.get("/pricing-config")
async def get_pricing_config():
    """Get full pricing configuration.

    Returns the config file's JSON content as stored, without merging user
    overrides into the base models.

    Raises:
        HTTPException: 500 if the config file does not exist, matching the
            error reported by ``load_pricing``.
    """
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as exc:
        raise HTTPException(
            status_code=500, detail="Pricing config not found"
        ) from exc
|
|
|
|
|
|
@router.post("/pricing/override")
async def set_price_override(override: PriceOverride):
    """Set a user override for model pricing.

    Stores the override under ``user_overrides`` in the config file,
    replacing any existing override for the same model. The stored entry
    always carries ``provider: "custom"``.

    Returns:
        A dict with a confirmation ``message`` and the stored ``override``.

    Raises:
        HTTPException: 500 if the config file does not exist, matching the
            error reported by ``load_pricing``.
    """
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError as exc:
        raise HTTPException(
            status_code=500, detail="Pricing config not found"
        ) from exc

    # Create the section when it is absent or null.
    if not config.get("user_overrides"):
        config["user_overrides"] = {}

    entry = {
        "input": override.input,
        "output": override.output,
        "description": override.description or f"User override for {override.model}",
        "provider": "custom",
    }
    config["user_overrides"][override.model] = entry

    save_pricing(config)

    return {
        "message": f"Price override set for {override.model}",
        "override": entry,
    }
|
|
|
|
|
|
@router.delete("/pricing/override/{model}")
async def delete_price_override(model: str):
    """Remove a user override for model pricing.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 if no override is stored for ``model``; 500 if
            the config file does not exist, matching the error reported by
            ``load_pricing``.
    """
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError as exc:
        raise HTTPException(
            status_code=500, detail="Pricing config not found"
        ) from exc

    # ``or {}`` treats an absent or null section as "no overrides".
    overrides = config.get("user_overrides") or {}
    if model not in overrides:
        raise HTTPException(status_code=404, detail=f"No override found for {model}")

    del overrides[model]
    save_pricing(config)
    return {"message": f"Override removed for {model}"}
|
|
|
|
|
|
@router.post("/pricing/add-model")
async def add_custom_model(override: PriceOverride):
    """Add a custom model to the pricing config.

    The model is stored under ``user_overrides`` with ``provider: "custom"``
    and ``context_window: 0``. No check is made that the name is unused: an
    existing override with the same name is replaced.

    Returns:
        A dict with a confirmation ``message`` and the stored ``model`` entry.

    Raises:
        HTTPException: 500 if the config file does not exist, matching the
            error reported by ``load_pricing``.
    """
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError as exc:
        raise HTTPException(
            status_code=500, detail="Pricing config not found"
        ) from exc

    # Create the section when it is absent or null.
    if not config.get("user_overrides"):
        config["user_overrides"] = {}

    entry = {
        "input": override.input,
        "output": override.output,
        "description": override.description or f"Custom model: {override.model}",
        "provider": "custom",
        "context_window": 0,
    }
    config["user_overrides"][override.model] = entry

    save_pricing(config)

    return {
        "message": f"Custom model {override.model} added",
        "model": entry,
    }
|