ai-tools-suite/backend/routers/eda.py
2025-12-27 15:33:06 +00:00

277 lines
8.4 KiB
Python

"""EDA Router - Gapminder Exploratory Data Analysis API"""
from fastapi import APIRouter, Query, HTTPException
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import pandas as pd
import numpy as np
from pathlib import Path
# Router object that the endpoint decorators in this module register on.
router = APIRouter()
# Location of the Gapminder TSV file, resolved relative to this module:
# <parent of this file's directory>/data/gapminder.tsv
# NOTE: the file is not read here; it is read lazily by get_df() on first use.
DATA_PATH = Path(__file__).parent.parent / "data" / "gapminder.tsv"
def load_gapminder() -> pd.DataFrame:
    """Read the tab-separated file at ``DATA_PATH`` and return it as a DataFrame."""
    frame = pd.read_csv(DATA_PATH, sep='\t')
    return frame
# Module-level cache for the dataset. It starts as None and is filled in by
# get_df() the first time the data is requested.
_df: Optional[pd.DataFrame] = None
def get_df() -> pd.DataFrame:
    """Return the cached Gapminder DataFrame, loading it on first use.

    The frame is kept in the module-level ``_df`` so later calls reuse it
    instead of re-reading the file. Callers receive the shared object itself,
    not a copy.
    """
    global _df
    if _df is not None:
        return _df
    _df = load_gapminder()
    return _df
# ========== PYDANTIC MODELS ==========
class DataResponse(BaseModel):
    """Response body for the ``/data`` endpoint."""

    # Result rows, each a mapping of column name to value.
    data: List[Dict[str, Any]]
    # Number of rows contained in ``data``.
    total: int
    # The query filters that were supplied, keyed by parameter name.
    filters_applied: Dict[str, Any]
class StatisticsResponse(BaseModel):
    """Response body for the ``/statistics`` endpoint."""

    # Name of the analyzed column.
    column: str
    # Overall descriptive statistics for ``column``.
    count: int
    mean: float
    std: float
    min: float
    # 25th, 50th and 75th percentiles.
    q25: float
    median: float
    q75: float
    max: float
    # Grouping column requested by the caller, if any.
    group_by: Optional[str] = None
    # Per-group statistics keyed by the stringified group value; each entry
    # maps a statistic name (mean, std, min, max, count) to its value.
    grouped_stats: Optional[Dict[str, Dict[str, float]]] = None
class CorrelationResponse(BaseModel):
    """Response body for the ``/correlation`` endpoint."""

    # Column names, in the same order as the rows and columns of ``matrix``.
    columns: List[str]
    # Correlation matrix as a list of rows.
    matrix: List[List[float]]
class TimeseriesResponse(BaseModel):
    """Response body for the ``/timeseries`` endpoint."""

    # Metric the caller asked to track.
    metric: str
    # One record per selected row, each a mapping of column name to value.
    data: List[Dict[str, Any]]
class RankingResponse(BaseModel):
    """Response body for the ``/ranking`` endpoint."""

    # Year the ranking was computed for.
    year: int
    # Metric the rows are ranked by.
    metric: str
    # Requested number of entries.
    top_n: int
    # Ranked records, each a mapping of column name to value.
    data: List[Dict[str, Any]]
class MetadataResponse(BaseModel):
    """Response body for the ``/metadata`` endpoint."""

    # Distinct values present in the dataset.
    countries: List[str]
    continents: List[str]
    years: List[int]
    # Column names of the dataset.
    columns: List[str]
    # Number of rows in the dataset.
    total_rows: int
# ========== ENDPOINTS ==========
@router.get("/metadata", response_model=MetadataResponse)
async def get_metadata():
    """Describe the dataset: distinct countries, continents and years, plus column names and row count."""
    frame = get_df()

    def distinct_sorted(column_name):
        # .tolist() turns the numpy array of unique values into plain Python
        # values before sorting.
        return sorted(frame[column_name].unique().tolist())

    country_names = distinct_sorted('country')
    continent_names = distinct_sorted('continent')
    available_years = distinct_sorted('year')

    return MetadataResponse(
        countries=country_names,
        continents=continent_names,
        years=available_years,
        columns=frame.columns.tolist(),
        total_rows=len(frame),
    )
@router.get("/data", response_model=DataResponse)
async def get_data(
    year: Optional[int] = Query(None, description="Filter by year"),
    continent: Optional[str] = Query(None, description="Filter by continent"),
    country: Optional[str] = Query(None, description="Filter by country"),
    limit: Optional[int] = Query(None, description="Limit number of results")
):
    """Return Gapminder rows, optionally filtered and truncated.

    Each supplied filter (``year``, ``continent``, ``country``) keeps only the
    rows whose column equals the given value. ``limit`` is applied last via
    ``DataFrame.head``. ``total`` in the response is the number of rows
    returned, i.e. it is counted after ``limit`` has been applied.
    """
    rows = get_df().copy()
    applied = {}

    # Equality filters, applied and recorded in a fixed order.
    equality_filters = (
        ('year', year),
        ('continent', continent),
        ('country', country),
    )
    for field, wanted in equality_filters:
        if wanted is None:
            continue
        rows = rows[rows[field] == wanted]
        applied[field] = wanted

    if limit is not None:
        rows = rows.head(limit)
        applied['limit'] = limit

    records = rows.to_dict(orient='records')
    return DataResponse(
        data=records,
        total=len(rows),
        filters_applied=applied,
    )
@router.get("/statistics", response_model=StatisticsResponse)
async def get_statistics(
    column: str = Query("lifeExp", description="Column to analyze (lifeExp, pop, gdpPercap)"),
    group_by: Optional[str] = Query(None, description="Group by column (continent, year)"),
    year: Optional[int] = Query(None, description="Filter by year first")
):
    """Get descriptive statistics for a numeric column.

    Args:
        column: Column to summarize; one of lifeExp, pop, gdpPercap.
        group_by: Optional grouping column; one of continent, year. When given,
            per-group mean/std/min/max/count are added as ``grouped_stats``.
        year: Optional year; when given, only rows of that year are analyzed.

    Returns:
        A StatisticsResponse with the overall statistics and, if requested,
        the per-group statistics keyed by the stringified group value.

    Raises:
        HTTPException: 400 if ``column`` or ``group_by`` is not an allowed
            value; 404 if no rows remain after filtering.
    """
    # Validate every parameter before doing any work so bad requests fail fast.
    if column not in ('lifeExp', 'pop', 'gdpPercap'):
        raise HTTPException(status_code=400, detail=f"Invalid column: {column}. Must be lifeExp, pop, or gdpPercap")
    if group_by is not None and group_by not in ('continent', 'year'):
        raise HTTPException(status_code=400, detail=f"Invalid group_by: {group_by}. Must be continent or year")

    # Read-only use below, so the shared cached frame does not need copying;
    # boolean-mask filtering already produces a new frame.
    df = get_df()
    if year is not None:
        df = df[df['year'] == year]

    # With no rows, describe() yields NaN for every statistic except count.
    # NaN is not valid JSON, so report the empty selection explicitly instead.
    if len(df) == 0:
        raise HTTPException(status_code=404, detail="No data available for the requested filters")

    def std_or_zero(value) -> float:
        # The sample standard deviation of a single observation is NaN;
        # report it as 0.0, for the overall and the grouped statistics alike.
        return 0.0 if pd.isna(value) else float(value)

    stats = df[column].describe()
    result = StatisticsResponse(
        column=column,
        count=int(stats['count']),
        mean=float(stats['mean']),
        std=std_or_zero(stats['std']),
        min=float(stats['min']),
        q25=float(stats['25%']),
        median=float(stats['50%']),
        q75=float(stats['75%']),
        max=float(stats['max']),
        group_by=group_by
    )

    if group_by is not None:
        grouped = df.groupby(group_by)[column].agg(['mean', 'std', 'min', 'max', 'count'])
        result.grouped_stats = {
            str(group_key): {
                'mean': float(row['mean']),
                'std': std_or_zero(row['std']),
                'min': float(row['min']),
                'max': float(row['max']),
                'count': int(row['count']),
            }
            for group_key, row in grouped.iterrows()
        }
    return result
@router.get("/correlation", response_model=CorrelationResponse)
async def get_correlation(
    year: Optional[int] = Query(None, description="Filter by year first")
):
    """Get the correlation matrix of lifeExp, pop and gdpPercap.

    Args:
        year: Optional year; when given, only rows of that year are used.

    Returns:
        A CorrelationResponse whose ``matrix`` rows and columns follow the
        order of ``columns``.

    Raises:
        HTTPException: 404 if fewer than two rows remain after filtering.
    """
    # Read-only use, so the shared cached frame does not need copying.
    df = get_df()
    if year is not None:
        df = df[df['year'] == year]

    # A correlation needs at least two observations. With fewer rows (e.g. a
    # year that is not in the dataset) corr() returns only NaN, which is not
    # valid JSON, so report the problem explicitly.
    if len(df) < 2:
        raise HTTPException(
            status_code=404,
            detail="Not enough data to compute a correlation matrix for the requested filters"
        )

    numeric_cols = ['lifeExp', 'pop', 'gdpPercap']
    corr_matrix = df[numeric_cols].corr()
    return CorrelationResponse(
        columns=numeric_cols,
        matrix=corr_matrix.values.tolist()
    )
@router.get("/timeseries", response_model=TimeseriesResponse)
async def get_timeseries(
    metric: str = Query("lifeExp", description="Metric to track (lifeExp, pop, gdpPercap)"),
    countries: Optional[str] = Query(None, description="Comma-separated list of countries"),
    continent: Optional[str] = Query(None, description="Filter by continent"),
    top_n: Optional[int] = Query(None, description="Get top N countries by latest value")
):
    """Return per-country rows across years for animated charts.

    Rows are first restricted to ``continent`` when given. After that an
    explicit ``countries`` list takes precedence over ``top_n``; ``top_n``
    picks the countries with the largest ``metric`` in the latest remaining
    year. Raises a 400 HTTPException for an unknown ``metric``.
    """
    frame = get_df().copy()
    if metric not in ('lifeExp', 'pop', 'gdpPercap'):
        raise HTTPException(status_code=400, detail=f"Invalid metric: {metric}")

    if continent is not None:
        frame = frame[frame['continent'] == continent]

    # Work out which countries to keep; None means "keep them all".
    selected = None
    if countries is not None:
        selected = [name.strip() for name in countries.split(',')]
    elif top_n is not None:
        newest_year = frame['year'].max()
        newest_rows = frame[frame['year'] == newest_year]
        selected = newest_rows.nlargest(top_n, metric)['country'].tolist()

    if selected is not None:
        frame = frame[frame['country'].isin(selected)]

    # Every column is returned (not just ``metric``) so a bubble chart can be drawn.
    output_columns = ['country', 'continent', 'year', 'lifeExp', 'pop', 'gdpPercap']
    records = frame[output_columns].to_dict(orient='records')
    return TimeseriesResponse(metric=metric, data=records)
@router.get("/ranking", response_model=RankingResponse)
async def get_ranking(
    year: int = Query(2007, description="Year to rank"),
    metric: str = Query("gdpPercap", description="Metric to rank by (lifeExp, pop, gdpPercap)"),
    top_n: int = Query(15, description="Number of top countries to return"),
    continent: Optional[str] = Query(None, description="Filter by continent")
):
    """Return the ``top_n`` rows with the largest ``metric`` for one year.

    Rows can additionally be restricted to a single ``continent``. Raises a
    400 HTTPException for an unknown ``metric``.
    """
    frame = get_df().copy()
    if metric not in ('lifeExp', 'pop', 'gdpPercap'):
        raise HTTPException(status_code=400, detail=f"Invalid metric: {metric}")

    # Build one row mask covering the year and, if requested, the continent.
    keep = frame['year'] == year
    if continent is not None:
        keep = keep & (frame['continent'] == continent)

    leaders = frame[keep].nlargest(top_n, metric)
    records = leaders[['country', 'continent', metric]].to_dict(orient='records')
    return RankingResponse(
        year=year,
        metric=metric,
        top_n=top_n,
        data=records,
    )
@router.get("/all-years-ranking")
async def get_all_years_ranking(
    metric: str = Query("gdpPercap", description="Metric to rank by"),
    top_n: int = Query(10, description="Number of top countries per year")
):
    """Get rankings for all years (for bar chart race animation).

    Args:
        metric: Column to rank by; one of lifeExp, pop, gdpPercap.
        top_n: Number of top countries to keep for each year.

    Returns:
        A dict with ``metric``, ``top_n``, the sorted list of ``years`` and
        ``data``: one record per (year, rank) holding year, rank (starting at
        1), country, continent and the metric value as ``value``.

    Raises:
        HTTPException: 400 if ``metric`` is not an allowed value.
    """
    # Read-only use, so the shared cached frame does not need copying.
    df = get_df()
    if metric not in ('lifeExp', 'pop', 'gdpPercap'):
        raise HTTPException(status_code=400, detail=f"Invalid metric: {metric}")

    # .tolist() converts the numpy integers to plain Python ints. This endpoint
    # has no response_model, so the returned dict is encoded as-is and numpy
    # scalars in it would not be JSON-serializable.
    years = sorted(df['year'].unique().tolist())

    result = []
    for year in years:
        year_df = df[df['year'] == year].nlargest(top_n, metric)
        ranked_rows = zip(year_df['country'], year_df['continent'], year_df[metric])
        result.extend(
            {
                'year': year,
                'rank': rank,
                'country': country,
                'continent': continent,
                'value': float(value),
            }
            for rank, (country, continent, value) in enumerate(ranked_rows, 1)
        )

    return {
        'metric': metric,
        'top_n': top_n,
        'years': years,
        'data': result
    }