"""Model Drift Monitor Router - Detect distribution shifts in ML features""" from fastapi import APIRouter, UploadFile, File, HTTPException, Form from pydantic import BaseModel from typing import Optional import numpy as np import duckdb import tempfile import os import json from datetime import datetime import hashlib router = APIRouter() # In-memory storage for baselines and history baselines_store: dict = {} drift_history: list = [] class DriftThresholds(BaseModel): psi_threshold: float = 0.2 # PSI > 0.2 indicates significant drift ks_threshold: float = 0.05 # KS p-value < 0.05 indicates drift alert_enabled: bool = True class FeatureDrift(BaseModel): feature: str psi_score: float ks_statistic: float ks_pvalue: float is_drifted: bool drift_type: str # "none", "minor", "moderate", "severe" baseline_stats: dict current_stats: dict class DriftResult(BaseModel): is_drifted: bool overall_score: float drift_severity: str drifted_features: int total_features: int feature_scores: list[FeatureDrift] method: str recommendations: list[str] timestamp: str engine: str = "DuckDB" # Current thresholds (in-memory, could be persisted) current_thresholds = DriftThresholds() async def read_to_duckdb(file: UploadFile) -> tuple[duckdb.DuckDBPyConnection, str]: """Read uploaded file into DuckDB in-memory database""" content = await file.read() filename = file.filename.lower() if file.filename else "file.csv" conn = duckdb.connect(":memory:") # Write to temp file for DuckDB to read suffix = '.csv' if filename.endswith('.csv') else '.json' if filename.endswith('.json') else '.csv' with tempfile.NamedTemporaryFile(mode='wb', suffix=suffix, delete=False) as tmp: tmp.write(content) tmp_path = tmp.name try: if filename.endswith('.csv'): conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')") elif filename.endswith('.json'): conn.execute(f"CREATE TABLE data AS SELECT * FROM read_json_auto('{tmp_path}')") else: conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')") finally: os.unlink(tmp_path) return conn, "data" def get_numeric_columns(conn: duckdb.DuckDBPyConnection, table_name: str) -> list[str]: """Get list of numeric columns from table""" schema = conn.execute(f"DESCRIBE {table_name}").fetchall() numeric_types = ['INTEGER', 'BIGINT', 'DOUBLE', 'FLOAT', 'DECIMAL', 'REAL', 'SMALLINT', 'TINYINT', 'HUGEINT'] return [col[0] for col in schema if any(t in col[1].upper() for t in numeric_types)] def calculate_psi(baseline_values: np.ndarray, current_values: np.ndarray, bins: int = 10) -> float: """ Calculate Population Stability Index (PSI) PSI < 0.1: No significant change 0.1 <= PSI < 0.2: Moderate change, monitoring needed PSI >= 0.2: Significant change, action required """ # Remove NaN values baseline_clean = baseline_values[~np.isnan(baseline_values)] current_clean = current_values[~np.isnan(current_values)] if len(baseline_clean) == 0 or len(current_clean) == 0: return 0.0 # Create bins based on baseline distribution min_val = min(baseline_clean.min(), current_clean.min()) max_val = max(baseline_clean.max(), current_clean.max()) if min_val == max_val: return 0.0 bin_edges = np.linspace(min_val, max_val, bins + 1) # Calculate proportions for each bin baseline_counts, _ = np.histogram(baseline_clean, bins=bin_edges) current_counts, _ = np.histogram(current_clean, bins=bin_edges) # Convert to proportions (add small epsilon to avoid division by zero) epsilon = 1e-6 baseline_prop = (baseline_counts + epsilon) / (len(baseline_clean) + epsilon * bins) current_prop = (current_counts + epsilon) / (len(current_clean) + epsilon * bins) # Calculate PSI psi = np.sum((current_prop - baseline_prop) * np.log(current_prop / baseline_prop)) return float(psi) def calculate_ks_statistic(baseline_values: np.ndarray, current_values: np.ndarray) -> tuple[float, float]: """ Calculate Kolmogorov-Smirnov statistic and approximate p-value """ # Remove NaN values baseline_clean = baseline_values[~np.isnan(baseline_values)] current_clean = current_values[~np.isnan(current_values)] if len(baseline_clean) == 0 or len(current_clean) == 0: return 0.0, 1.0 # Sort both arrays baseline_sorted = np.sort(baseline_clean) current_sorted = np.sort(current_clean) # Create combined array of all values all_values = np.concatenate([baseline_sorted, current_sorted]) all_values = np.sort(np.unique(all_values)) # Calculate CDFs baseline_cdf = np.searchsorted(baseline_sorted, all_values, side='right') / len(baseline_sorted) current_cdf = np.searchsorted(current_sorted, all_values, side='right') / len(current_sorted) # KS statistic is the maximum difference ks_stat = float(np.max(np.abs(baseline_cdf - current_cdf))) # Approximate p-value using asymptotic formula n1, n2 = len(baseline_clean), len(current_clean) en = np.sqrt(n1 * n2 / (n1 + n2)) # Kolmogorov distribution approximation lambda_val = (en + 0.12 + 0.11 / en) * ks_stat # Two-sided p-value approximation if lambda_val < 0.001: p_value = 1.0 else: # Approximation using exponential terms j = np.arange(1, 101) p_value = 2 * np.sum((-1) ** (j - 1) * np.exp(-2 * j ** 2 * lambda_val ** 2)) p_value = max(0.0, min(1.0, p_value)) return ks_stat, float(p_value) def get_column_stats(conn: duckdb.DuckDBPyConnection, table_name: str, column: str) -> dict: """Get statistics for a column using DuckDB""" try: stats = conn.execute(f''' SELECT COUNT(*) as count, COUNT("{column}") as non_null, AVG("{column}"::DOUBLE) as mean, STDDEV("{column}"::DOUBLE) as std, MIN("{column}"::DOUBLE) as min, MAX("{column}"::DOUBLE) as max, PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY "{column}") as q1, PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY "{column}") as median, PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{column}") as q3 FROM {table_name} ''').fetchone() return { "count": stats[0], "non_null": stats[1], "mean": float(stats[2]) if stats[2] is not None else None, "std": float(stats[3]) if stats[3] is not None else None, "min": float(stats[4]) if stats[4] is not None else None, "max": float(stats[5]) if stats[5] is not None else None, "q1": float(stats[6]) if stats[6] is not None else None, "median": float(stats[7]) if stats[7] is not None else None, "q3": float(stats[8]) if stats[8] is not None else None } except Exception: return {"count": 0, "non_null": 0} def classify_drift(psi: float, ks_pvalue: float, psi_threshold: float, ks_threshold: float) -> tuple[bool, str]: """Classify drift severity based on PSI and KS test""" is_drifted = psi >= psi_threshold or ks_pvalue < ks_threshold if psi >= 0.25 or ks_pvalue < 0.01: return True, "severe" elif psi >= 0.2 or ks_pvalue < 0.05: return True, "moderate" elif psi >= 0.1: return True, "minor" else: return is_drifted, "none" def generate_recommendations(feature_scores: list[FeatureDrift], overall_drifted: bool) -> list[str]: """Generate actionable recommendations based on drift analysis""" recommendations = [] severe_features = [f.feature for f in feature_scores if f.drift_type == "severe"] moderate_features = [f.feature for f in feature_scores if f.drift_type == "moderate"] minor_features = [f.feature for f in feature_scores if f.drift_type == "minor"] if severe_features: recommendations.append(f"🚨 CRITICAL: Severe drift detected in {len(severe_features)} feature(s): {', '.join(severe_features[:5])}. Immediate model retraining recommended.") recommendations.append("Consider rolling back to a previous model version if performance degradation is observed.") if moderate_features: recommendations.append(f"âš ī¸ WARNING: Moderate drift in {len(moderate_features)} feature(s): {', '.join(moderate_features[:5])}. Schedule model retraining within 1-2 weeks.") recommendations.append("Monitor model performance metrics closely for these features.") if minor_features: recommendations.append(f"â„šī¸ INFO: Minor drift detected in {len(minor_features)} feature(s). Continue monitoring.") if overall_drifted: recommendations.append("Update baseline distributions after addressing drift to reset monitoring.") recommendations.append("Investigate data pipeline changes that may have caused distribution shifts.") recommendations.append("Consider feature engineering adjustments for drifted features.") else: recommendations.append("✅ No significant drift detected. Model distributions are stable.") recommendations.append("Continue regular monitoring at current frequency.") return recommendations @router.post("/baseline") async def upload_baseline( file: UploadFile = File(...), name: Optional[str] = Form(None) ): """Upload baseline distribution for comparison""" try: conn, table_name = await read_to_duckdb(file) numeric_cols = get_numeric_columns(conn, table_name) if not numeric_cols: raise HTTPException(status_code=400, detail="No numeric columns found in the dataset") # Generate baseline ID baseline_id = hashlib.md5(f"{file.filename}_{datetime.now().isoformat()}".encode()).hexdigest()[:12] # Store baseline statistics and raw values for each column baseline_data = { "id": baseline_id, "name": name or file.filename, "filename": file.filename, "created_at": datetime.now().isoformat(), "row_count": conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0], "columns": {}, "values": {} } for col in numeric_cols: baseline_data["columns"][col] = get_column_stats(conn, table_name, col) # Store actual values for PSI/KS calculation values = conn.execute(f'SELECT "{col}"::DOUBLE FROM {table_name} WHERE "{col}" IS NOT NULL').fetchall() baseline_data["values"][col] = np.array([v[0] for v in values]) baselines_store[baseline_id] = baseline_data conn.close() return { "message": "Baseline uploaded successfully", "baseline_id": baseline_id, "name": baseline_data["name"], "filename": file.filename, "row_count": baseline_data["row_count"], "numeric_columns": numeric_cols, "column_stats": baseline_data["columns"], "engine": "DuckDB" } except HTTPException: raise except Exception as e: raise HTTPException(status_code=400, detail=f"Error processing baseline file: {str(e)}") @router.get("/baselines") async def list_baselines(): """List all stored baselines""" return { "baselines": [ { "id": b["id"], "name": b["name"], "filename": b["filename"], "created_at": b["created_at"], "row_count": b["row_count"], "columns": list(b["columns"].keys()) } for b in baselines_store.values() ] } @router.delete("/baseline/{baseline_id}") async def delete_baseline(baseline_id: str): """Delete a stored baseline""" if baseline_id not in baselines_store: raise HTTPException(status_code=404, detail="Baseline not found") del baselines_store[baseline_id] return {"message": "Baseline deleted", "baseline_id": baseline_id} @router.post("/analyze") async def analyze_drift( production_file: UploadFile = File(...), baseline_id: str = Form(...) ): """Analyze production data for drift against baseline""" if baseline_id not in baselines_store: raise HTTPException(status_code=404, detail=f"Baseline '{baseline_id}' not found. Upload a baseline first.") try: baseline = baselines_store[baseline_id] conn, table_name = await read_to_duckdb(production_file) numeric_cols = get_numeric_columns(conn, table_name) common_cols = [col for col in numeric_cols if col in baseline["columns"]] if not common_cols: raise HTTPException(status_code=400, detail="No matching numeric columns found between production data and baseline") feature_scores = [] total_psi = 0.0 drifted_count = 0 for col in common_cols: # Get current values current_values = conn.execute(f'SELECT "{col}"::DOUBLE FROM {table_name} WHERE "{col}" IS NOT NULL').fetchall() current_arr = np.array([v[0] for v in current_values]) baseline_arr = baseline["values"][col] # Calculate drift metrics psi = calculate_psi(baseline_arr, current_arr) ks_stat, ks_pvalue = calculate_ks_statistic(baseline_arr, current_arr) # Classify drift is_drifted, drift_type = classify_drift(psi, ks_pvalue, current_thresholds.psi_threshold, current_thresholds.ks_threshold) if is_drifted: drifted_count += 1 total_psi += psi feature_scores.append(FeatureDrift( feature=col, psi_score=round(psi, 4), ks_statistic=round(ks_stat, 4), ks_pvalue=round(ks_pvalue, 4), is_drifted=is_drifted, drift_type=drift_type, baseline_stats=baseline["columns"][col], current_stats=get_column_stats(conn, table_name, col) )) conn.close() # Calculate overall drift avg_psi = total_psi / len(common_cols) if common_cols else 0 overall_drifted = drifted_count > 0 # Determine severity severe_count = len([f for f in feature_scores if f.drift_type == "severe"]) moderate_count = len([f for f in feature_scores if f.drift_type == "moderate"]) if severe_count > 0: drift_severity = "severe" elif moderate_count > 0: drift_severity = "moderate" elif drifted_count > 0: drift_severity = "minor" else: drift_severity = "none" # Generate recommendations recommendations = generate_recommendations(feature_scores, overall_drifted) # Create result result = DriftResult( is_drifted=overall_drifted, overall_score=round(avg_psi, 4), drift_severity=drift_severity, drifted_features=drifted_count, total_features=len(common_cols), feature_scores=feature_scores, method="PSI + Kolmogorov-Smirnov", recommendations=recommendations, timestamp=datetime.now().isoformat(), engine="DuckDB" ) # Store in history drift_history.append({ "baseline_id": baseline_id, "production_file": production_file.filename, "timestamp": result.timestamp, "is_drifted": result.is_drifted, "overall_score": result.overall_score, "drift_severity": result.drift_severity, "drifted_features": result.drifted_features, "total_features": result.total_features }) return result except HTTPException: raise except Exception as e: raise HTTPException(status_code=400, detail=f"Error analyzing drift: {str(e)}") @router.post("/compare-files") async def compare_two_files( baseline_file: UploadFile = File(...), production_file: UploadFile = File(...) ): """Compare two files directly without storing baseline""" try: # Load both files baseline_conn, baseline_table = await read_to_duckdb(baseline_file) # Need to reset file position for second read production_content = await production_file.read() # Create production connection prod_conn = duckdb.connect(":memory:") filename = production_file.filename.lower() if production_file.filename else "file.csv" suffix = '.csv' if filename.endswith('.csv') else '.json' if filename.endswith('.json') else '.csv' with tempfile.NamedTemporaryFile(mode='wb', suffix=suffix, delete=False) as tmp: tmp.write(production_content) tmp_path = tmp.name try: if filename.endswith('.csv'): prod_conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')") elif filename.endswith('.json'): prod_conn.execute(f"CREATE TABLE data AS SELECT * FROM read_json_auto('{tmp_path}')") else: prod_conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')") finally: os.unlink(tmp_path) prod_table = "data" # Get common numeric columns baseline_cols = get_numeric_columns(baseline_conn, baseline_table) prod_cols = get_numeric_columns(prod_conn, prod_table) common_cols = list(set(baseline_cols) & set(prod_cols)) if not common_cols: raise HTTPException(status_code=400, detail="No matching numeric columns found between files") feature_scores = [] total_psi = 0.0 drifted_count = 0 for col in common_cols: # Get values from both files baseline_values = baseline_conn.execute(f'SELECT "{col}"::DOUBLE FROM {baseline_table} WHERE "{col}" IS NOT NULL').fetchall() prod_values = prod_conn.execute(f'SELECT "{col}"::DOUBLE FROM {prod_table} WHERE "{col}" IS NOT NULL').fetchall() baseline_arr = np.array([v[0] for v in baseline_values]) prod_arr = np.array([v[0] for v in prod_values]) # Calculate drift metrics psi = calculate_psi(baseline_arr, prod_arr) ks_stat, ks_pvalue = calculate_ks_statistic(baseline_arr, prod_arr) # Classify drift is_drifted, drift_type = classify_drift(psi, ks_pvalue, current_thresholds.psi_threshold, current_thresholds.ks_threshold) if is_drifted: drifted_count += 1 total_psi += psi feature_scores.append(FeatureDrift( feature=col, psi_score=round(psi, 4), ks_statistic=round(ks_stat, 4), ks_pvalue=round(ks_pvalue, 4), is_drifted=is_drifted, drift_type=drift_type, baseline_stats=get_column_stats(baseline_conn, baseline_table, col), current_stats=get_column_stats(prod_conn, prod_table, col) )) baseline_conn.close() prod_conn.close() # Calculate overall drift avg_psi = total_psi / len(common_cols) if common_cols else 0 overall_drifted = drifted_count > 0 # Determine severity severe_count = len([f for f in feature_scores if f.drift_type == "severe"]) moderate_count = len([f for f in feature_scores if f.drift_type == "moderate"]) if severe_count > 0: drift_severity = "severe" elif moderate_count > 0: drift_severity = "moderate" elif drifted_count > 0: drift_severity = "minor" else: drift_severity = "none" recommendations = generate_recommendations(feature_scores, overall_drifted) return DriftResult( is_drifted=overall_drifted, overall_score=round(avg_psi, 4), drift_severity=drift_severity, drifted_features=drifted_count, total_features=len(common_cols), feature_scores=feature_scores, method="PSI + Kolmogorov-Smirnov", recommendations=recommendations, timestamp=datetime.now().isoformat(), engine="DuckDB" ) except HTTPException: raise except Exception as e: raise HTTPException(status_code=400, detail=f"Error comparing files: {str(e)}") @router.get("/history") async def get_drift_history(limit: int = 100): """Get historical drift analysis results""" return { "history": drift_history[-limit:], "total_analyses": len(drift_history) } @router.put("/thresholds") async def update_thresholds(thresholds: DriftThresholds): """Update drift detection thresholds""" global current_thresholds current_thresholds = thresholds return { "message": "Thresholds updated", "thresholds": { "psi_threshold": current_thresholds.psi_threshold, "ks_threshold": current_thresholds.ks_threshold, "alert_enabled": current_thresholds.alert_enabled } } @router.get("/thresholds") async def get_thresholds(): """Get current drift detection thresholds""" return { "psi_threshold": current_thresholds.psi_threshold, "ks_threshold": current_thresholds.ks_threshold, "alert_enabled": current_thresholds.alert_enabled, "psi_interpretation": { "low": "PSI < 0.1 - No significant change", "moderate": "0.1 <= PSI < 0.2 - Moderate change, monitoring needed", "high": "PSI >= 0.2 - Significant change, action required" } }