"""EDA Router - Gapminder Exploratory Data Analysis API"""
|
|
from fastapi import APIRouter, Query, HTTPException
|
|
from pydantic import BaseModel
|
|
from typing import Optional, List, Dict, Any
|
|
import pandas as pd
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
router = APIRouter()
|
|
|
|
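# How this router is typically wired into an application (illustrative sketch only;
# the importing module path and URL prefix are assumptions, not defined in this file):
#
#   from fastapi import FastAPI
#   from routers.eda import router as eda_router
#
#   app = FastAPI()
#   app.include_router(eda_router, prefix="/api/eda")
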
# Load data once at startup
DATA_PATH = Path(__file__).parent.parent / "data" / "gapminder.tsv"


def load_gapminder() -> pd.DataFrame:
    """Load the gapminder dataset from the bundled TSV file."""
    return pd.read_csv(DATA_PATH, sep='\t')


# Module-level cache for the dataframe
_df: Optional[pd.DataFrame] = None

def get_df() -> pd.DataFrame:
    """Return the cached dataframe, loading it on first access."""
    global _df
    if _df is None:
        _df = load_gapminder()
    return _df

# ========== PYDANTIC MODELS ==========

class DataResponse(BaseModel):
    data: List[Dict[str, Any]]
    total: int
    filters_applied: Dict[str, Any]


class StatisticsResponse(BaseModel):
    column: str
    count: int
    mean: float
    std: float
    min: float
    q25: float
    median: float
    q75: float
    max: float
    group_by: Optional[str] = None
    grouped_stats: Optional[Dict[str, Dict[str, float]]] = None


class CorrelationResponse(BaseModel):
    columns: List[str]
    matrix: List[List[float]]


class TimeseriesResponse(BaseModel):
    metric: str
    data: List[Dict[str, Any]]


class RankingResponse(BaseModel):
    year: int
    metric: str
    top_n: int
    data: List[Dict[str, Any]]


class MetadataResponse(BaseModel):
    countries: List[str]
    continents: List[str]
    years: List[int]
    columns: List[str]
    total_rows: int

# ========== ENDPOINTS ==========

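# Example request (illustrative; the final URL also depends on the prefix used
# when this router is mounted):
#   GET /metadata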
@router.get("/metadata", response_model=MetadataResponse)
|
|
async def get_metadata():
|
|
"""Get dataset metadata - available countries, continents, years"""
|
|
df = get_df()
|
|
return MetadataResponse(
|
|
countries=sorted(df['country'].unique().tolist()),
|
|
continents=sorted(df['continent'].unique().tolist()),
|
|
years=sorted(df['year'].unique().tolist()),
|
|
columns=df.columns.tolist(),
|
|
total_rows=len(df)
|
|
)
|
|
|
|
|
|
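# Example request (illustrative; the filter values are sample inputs):
#   GET /data?year=2007&continent=Asia&limit=10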
@router.get("/data", response_model=DataResponse)
|
|
async def get_data(
|
|
year: Optional[int] = Query(None, description="Filter by year"),
|
|
continent: Optional[str] = Query(None, description="Filter by continent"),
|
|
country: Optional[str] = Query(None, description="Filter by country"),
|
|
limit: Optional[int] = Query(None, description="Limit number of results")
|
|
):
|
|
"""Get filtered gapminder data"""
|
|
df = get_df().copy()
|
|
filters = {}
|
|
|
|
if year is not None:
|
|
df = df[df['year'] == year]
|
|
filters['year'] = year
|
|
|
|
if continent is not None:
|
|
df = df[df['continent'] == continent]
|
|
filters['continent'] = continent
|
|
|
|
if country is not None:
|
|
df = df[df['country'] == country]
|
|
filters['country'] = country
|
|
|
|
if limit is not None:
|
|
df = df.head(limit)
|
|
filters['limit'] = limit
|
|
|
|
return DataResponse(
|
|
data=df.to_dict(orient='records'),
|
|
total=len(df),
|
|
filters_applied=filters
|
|
)
|
|
|
|
|
|
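# Example request (illustrative):
#   GET /statistics?column=lifeExp&group_by=continent&year=2007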
@router.get("/statistics", response_model=StatisticsResponse)
|
|
async def get_statistics(
|
|
column: str = Query("lifeExp", description="Column to analyze (lifeExp, pop, gdpPercap)"),
|
|
group_by: Optional[str] = Query(None, description="Group by column (continent, year)"),
|
|
year: Optional[int] = Query(None, description="Filter by year first")
|
|
):
|
|
"""Get descriptive statistics for a numeric column"""
|
|
df = get_df().copy()
|
|
|
|
if column not in ['lifeExp', 'pop', 'gdpPercap']:
|
|
raise HTTPException(status_code=400, detail=f"Invalid column: {column}. Must be lifeExp, pop, or gdpPercap")
|
|
|
|
if year is not None:
|
|
df = df[df['year'] == year]
|
|
|
|
stats = df[column].describe()
|
|
|
|
result = StatisticsResponse(
|
|
column=column,
|
|
count=int(stats['count']),
|
|
mean=float(stats['mean']),
|
|
std=float(stats['std']),
|
|
min=float(stats['min']),
|
|
q25=float(stats['25%']),
|
|
median=float(stats['50%']),
|
|
q75=float(stats['75%']),
|
|
max=float(stats['max']),
|
|
group_by=group_by
|
|
)
|
|
|
|
if group_by is not None:
|
|
if group_by not in ['continent', 'year']:
|
|
raise HTTPException(status_code=400, detail=f"Invalid group_by: {group_by}. Must be continent or year")
|
|
|
|
grouped = df.groupby(group_by)[column].agg(['mean', 'std', 'min', 'max', 'count'])
|
|
grouped_stats = {}
|
|
for idx, row in grouped.iterrows():
|
|
grouped_stats[str(idx)] = {
|
|
'mean': float(row['mean']),
|
|
'std': float(row['std']) if not pd.isna(row['std']) else 0.0,
|
|
'min': float(row['min']),
|
|
'max': float(row['max']),
|
|
'count': int(row['count'])
|
|
}
|
|
result.grouped_stats = grouped_stats
|
|
|
|
return result
|
|
|
|
|
|
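# Example request (illustrative):
#   GET /correlation?year=2007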
@router.get("/correlation", response_model=CorrelationResponse)
|
|
async def get_correlation(
|
|
year: Optional[int] = Query(None, description="Filter by year first")
|
|
):
|
|
"""Get correlation matrix for numeric columns"""
|
|
df = get_df().copy()
|
|
|
|
if year is not None:
|
|
df = df[df['year'] == year]
|
|
|
|
numeric_cols = ['lifeExp', 'pop', 'gdpPercap']
|
|
corr_matrix = df[numeric_cols].corr()
|
|
|
|
return CorrelationResponse(
|
|
columns=numeric_cols,
|
|
matrix=corr_matrix.values.tolist()
|
|
)
|
|
|
|
|
|
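# Example requests (illustrative; country names are sample values):
#   GET /timeseries?metric=gdpPercap&countries=China,India
#   GET /timeseries?metric=pop&continent=Asia&top_n=5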
@router.get("/timeseries", response_model=TimeseriesResponse)
|
|
async def get_timeseries(
|
|
metric: str = Query("lifeExp", description="Metric to track (lifeExp, pop, gdpPercap)"),
|
|
countries: Optional[str] = Query(None, description="Comma-separated list of countries"),
|
|
continent: Optional[str] = Query(None, description="Filter by continent"),
|
|
top_n: Optional[int] = Query(None, description="Get top N countries by latest value")
|
|
):
|
|
"""Get time series data for animated charts"""
|
|
df = get_df().copy()
|
|
|
|
if metric not in ['lifeExp', 'pop', 'gdpPercap']:
|
|
raise HTTPException(status_code=400, detail=f"Invalid metric: {metric}")
|
|
|
|
if continent is not None:
|
|
df = df[df['continent'] == continent]
|
|
|
|
if countries is not None:
|
|
country_list = [c.strip() for c in countries.split(',')]
|
|
df = df[df['country'].isin(country_list)]
|
|
elif top_n is not None:
|
|
# Get top N countries by latest year value
|
|
latest_year = df['year'].max()
|
|
top_countries = df[df['year'] == latest_year].nlargest(top_n, metric)['country'].tolist()
|
|
df = df[df['country'].isin(top_countries)]
|
|
|
|
# Return data formatted for animation (all columns needed for bubble chart)
|
|
return TimeseriesResponse(
|
|
metric=metric,
|
|
data=df[['country', 'continent', 'year', 'lifeExp', 'pop', 'gdpPercap']].to_dict(orient='records')
|
|
)
|
|
|
|
|
|
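# Example request (illustrative):
#   GET /ranking?year=2007&metric=gdpPercap&top_n=15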
@router.get("/ranking", response_model=RankingResponse)
|
|
async def get_ranking(
|
|
year: int = Query(2007, description="Year to rank"),
|
|
metric: str = Query("gdpPercap", description="Metric to rank by (lifeExp, pop, gdpPercap)"),
|
|
top_n: int = Query(15, description="Number of top countries to return"),
|
|
continent: Optional[str] = Query(None, description="Filter by continent")
|
|
):
|
|
"""Get ranked data for bar chart race"""
|
|
df = get_df().copy()
|
|
|
|
if metric not in ['lifeExp', 'pop', 'gdpPercap']:
|
|
raise HTTPException(status_code=400, detail=f"Invalid metric: {metric}")
|
|
|
|
df = df[df['year'] == year]
|
|
|
|
if continent is not None:
|
|
df = df[df['continent'] == continent]
|
|
|
|
df = df.nlargest(top_n, metric)
|
|
|
|
return RankingResponse(
|
|
year=year,
|
|
metric=metric,
|
|
top_n=top_n,
|
|
data=df[['country', 'continent', metric]].to_dict(orient='records')
|
|
)
|
|
|
|
|
|
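# Example request (illustrative):
#   GET /all-years-ranking?metric=pop&top_n=10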
@router.get("/all-years-ranking")
|
|
async def get_all_years_ranking(
|
|
metric: str = Query("gdpPercap", description="Metric to rank by"),
|
|
top_n: int = Query(10, description="Number of top countries per year")
|
|
):
|
|
"""Get rankings for all years (for bar chart race animation)"""
|
|
df = get_df().copy()
|
|
|
|
if metric not in ['lifeExp', 'pop', 'gdpPercap']:
|
|
raise HTTPException(status_code=400, detail=f"Invalid metric: {metric}")
|
|
|
|
years = sorted(df['year'].unique())
|
|
result = []
|
|
|
|
for year in years:
|
|
year_df = df[df['year'] == year].nlargest(top_n, metric)
|
|
for rank, (_, row) in enumerate(year_df.iterrows(), 1):
|
|
result.append({
|
|
'year': int(year),
|
|
'rank': rank,
|
|
'country': row['country'],
|
|
'continent': row['continent'],
|
|
'value': float(row[metric])
|
|
})
|
|
|
|
return {
|
|
'metric': metric,
|
|
'top_n': top_n,
|
|
'years': years,
|
|
'data': result
|
|
}
|