ai-tools-suite/backend/routers/drift.py
"""Model Drift Monitor Router - Detect distribution shifts in ML features"""
from fastapi import APIRouter, UploadFile, File, HTTPException, Form
from pydantic import BaseModel
from typing import Optional
import numpy as np
import duckdb
import tempfile
import os
import json
from datetime import datetime
import hashlib
router = APIRouter()
# In-memory storage for baselines and history
baselines_store: dict = {}
drift_history: list = []
class DriftThresholds(BaseModel):
    psi_threshold: float = 0.2   # PSI > 0.2 indicates significant drift
    ks_threshold: float = 0.05   # KS p-value < 0.05 indicates drift
    alert_enabled: bool = True

class FeatureDrift(BaseModel):
    feature: str
    psi_score: float
    ks_statistic: float
    ks_pvalue: float
    is_drifted: bool
    drift_type: str  # "none", "minor", "moderate", "severe"
    baseline_stats: dict
    current_stats: dict

class DriftResult(BaseModel):
    is_drifted: bool
    overall_score: float
    drift_severity: str
    drifted_features: int
    total_features: int
    feature_scores: list[FeatureDrift]
    method: str
    recommendations: list[str]
    timestamp: str
    engine: str = "DuckDB"


# Current thresholds (in-memory, could be persisted)
current_thresholds = DriftThresholds()

async def read_to_duckdb(file: UploadFile) -> tuple[duckdb.DuckDBPyConnection, str]:
    """Read an uploaded file into an in-memory DuckDB database"""
    content = await file.read()
    filename = file.filename.lower() if file.filename else "file.csv"
    conn = duckdb.connect(":memory:")
    # Write to a temp file for DuckDB to read
    suffix = '.json' if filename.endswith('.json') else '.csv'
    with tempfile.NamedTemporaryFile(mode='wb', suffix=suffix, delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name
    try:
        if filename.endswith('.json'):
            conn.execute(f"CREATE TABLE data AS SELECT * FROM read_json_auto('{tmp_path}')")
        else:
            conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')")
    finally:
        os.unlink(tmp_path)
    return conn, "data"

def get_numeric_columns(conn: duckdb.DuckDBPyConnection, table_name: str) -> list[str]:
    """Get the list of numeric columns in a table"""
    schema = conn.execute(f"DESCRIBE {table_name}").fetchall()
    numeric_types = ['INTEGER', 'BIGINT', 'DOUBLE', 'FLOAT', 'DECIMAL', 'REAL', 'SMALLINT', 'TINYINT', 'HUGEINT']
    return [col[0] for col in schema if any(t in col[1].upper() for t in numeric_types)]

def calculate_psi(baseline_values: np.ndarray, current_values: np.ndarray, bins: int = 10) -> float:
    """
    Calculate the Population Stability Index (PSI).

    PSI < 0.1: no significant change
    0.1 <= PSI < 0.2: moderate change, monitoring needed
    PSI >= 0.2: significant change, action required
    """
    # Remove NaN values
    baseline_clean = baseline_values[~np.isnan(baseline_values)]
    current_clean = current_values[~np.isnan(current_values)]
    if len(baseline_clean) == 0 or len(current_clean) == 0:
        return 0.0
    # Create bins spanning the combined range of both distributions
    min_val = min(baseline_clean.min(), current_clean.min())
    max_val = max(baseline_clean.max(), current_clean.max())
    if min_val == max_val:
        return 0.0
    bin_edges = np.linspace(min_val, max_val, bins + 1)
    # Bin counts for each distribution
    baseline_counts, _ = np.histogram(baseline_clean, bins=bin_edges)
    current_counts, _ = np.histogram(current_clean, bins=bin_edges)
    # Convert to proportions (a small epsilon avoids division by zero and log of zero)
    epsilon = 1e-6
    baseline_prop = (baseline_counts + epsilon) / (len(baseline_clean) + epsilon * bins)
    current_prop = (current_counts + epsilon) / (len(current_clean) + epsilon * bins)
    # PSI = sum((current% - baseline%) * ln(current% / baseline%))
    psi = np.sum((current_prop - baseline_prop) * np.log(current_prop / baseline_prop))
    return float(psi)

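# Illustrative sanity check, not part of the router: with two samples from the same
# distribution the PSI should sit near zero, while a clear mean shift should push it
# well past the 0.2 "significant change" cut-off. A hypothetical quick test:
#
#   rng = np.random.default_rng(0)
#   base = rng.normal(0.0, 1.0, 10_000)
#   same = rng.normal(0.0, 1.0, 10_000)
#   shifted = rng.normal(1.0, 1.0, 10_000)
#   calculate_psi(base, same)      # close to 0 -> no significant change
#   calculate_psi(base, shifted)   # well above 0.2 -> significant change
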
def calculate_ks_statistic(baseline_values: np.ndarray, current_values: np.ndarray) -> tuple[float, float]:
    """
    Calculate the two-sample Kolmogorov-Smirnov statistic and an approximate two-sided p-value.
    """
    # Remove NaN values
    baseline_clean = baseline_values[~np.isnan(baseline_values)]
    current_clean = current_values[~np.isnan(current_values)]
    if len(baseline_clean) == 0 or len(current_clean) == 0:
        return 0.0, 1.0
    # Sort both samples
    baseline_sorted = np.sort(baseline_clean)
    current_sorted = np.sort(current_clean)
    # Evaluate both empirical CDFs at every distinct observed value
    all_values = np.unique(np.concatenate([baseline_sorted, current_sorted]))
    baseline_cdf = np.searchsorted(baseline_sorted, all_values, side='right') / len(baseline_sorted)
    current_cdf = np.searchsorted(current_sorted, all_values, side='right') / len(current_sorted)
    # The KS statistic is the maximum distance between the two CDFs
    ks_stat = float(np.max(np.abs(baseline_cdf - current_cdf)))
    # Approximate the p-value with the asymptotic Kolmogorov distribution
    n1, n2 = len(baseline_clean), len(current_clean)
    en = np.sqrt(n1 * n2 / (n1 + n2))
    lambda_val = (en + 0.12 + 0.11 / en) * ks_stat
    if lambda_val < 0.001:
        p_value = 1.0
    else:
        # Truncated alternating series: 2 * sum_{j>=1} (-1)^(j-1) * exp(-2 * j^2 * lambda^2)
        j = np.arange(1, 101)
        p_value = 2 * np.sum((-1) ** (j - 1) * np.exp(-2 * j ** 2 * lambda_val ** 2))
        p_value = max(0.0, min(1.0, p_value))
    return ks_stat, float(p_value)

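# Illustrative interpretation, not part of the router: ks_stat is the largest gap between
# the two empirical CDFs (0 = identical samples, 1 = fully separated), and the p-value uses
# the asymptotic series rather than an exact test, so for very small samples it is only a
# rough guide. A hypothetical call:
#
#   ks_stat, p = calculate_ks_statistic(np.array([1.0, 2.0, 3.0] * 100),
#                                       np.array([2.0, 3.0, 4.0] * 100))
#   # ks_stat is about 1/3 here, and the small p-value would be read as evidence of drift
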
def get_column_stats(conn: duckdb.DuckDBPyConnection, table_name: str, column: str) -> dict:
    """Get summary statistics for a column using DuckDB"""
    try:
        stats = conn.execute(f'''
            SELECT
                COUNT(*) as count,
                COUNT("{column}") as non_null,
                AVG("{column}"::DOUBLE) as mean,
                STDDEV("{column}"::DOUBLE) as std,
                MIN("{column}"::DOUBLE) as min,
                MAX("{column}"::DOUBLE) as max,
                PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY "{column}") as q1,
                PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY "{column}") as median,
                PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{column}") as q3
            FROM {table_name}
        ''').fetchone()
        return {
            "count": stats[0],
            "non_null": stats[1],
            "mean": float(stats[2]) if stats[2] is not None else None,
            "std": float(stats[3]) if stats[3] is not None else None,
            "min": float(stats[4]) if stats[4] is not None else None,
            "max": float(stats[5]) if stats[5] is not None else None,
            "q1": float(stats[6]) if stats[6] is not None else None,
            "median": float(stats[7]) if stats[7] is not None else None,
            "q3": float(stats[8]) if stats[8] is not None else None
        }
    except Exception:
        return {"count": 0, "non_null": 0}

def classify_drift(psi: float, ks_pvalue: float, psi_threshold: float, ks_threshold: float) -> tuple[bool, str]:
    """Classify drift severity based on the PSI score and the KS test p-value"""
    is_drifted = psi >= psi_threshold or ks_pvalue < ks_threshold
    # Severity tiers use fixed cut-offs (PSI 0.1 / 0.2 / 0.25, KS p-value 0.05 / 0.01);
    # the configurable thresholds only affect the flag returned in the "none" tier.
    if psi >= 0.25 or ks_pvalue < 0.01:
        return True, "severe"
    elif psi >= 0.2 or ks_pvalue < 0.05:
        return True, "moderate"
    elif psi >= 0.1:
        return True, "minor"
    else:
        return is_drifted, "none"

def generate_recommendations(feature_scores: list[FeatureDrift], overall_drifted: bool) -> list[str]:
    """Generate actionable recommendations based on the drift analysis"""
    recommendations = []
    severe_features = [f.feature for f in feature_scores if f.drift_type == "severe"]
    moderate_features = [f.feature for f in feature_scores if f.drift_type == "moderate"]
    minor_features = [f.feature for f in feature_scores if f.drift_type == "minor"]
    if severe_features:
        recommendations.append(f"🚨 CRITICAL: Severe drift detected in {len(severe_features)} feature(s): {', '.join(severe_features[:5])}. Immediate model retraining recommended.")
        recommendations.append("Consider rolling back to a previous model version if performance degradation is observed.")
    if moderate_features:
        recommendations.append(f"⚠️ WARNING: Moderate drift in {len(moderate_features)} feature(s): {', '.join(moderate_features[:5])}. Schedule model retraining within 1-2 weeks.")
        recommendations.append("Monitor model performance metrics closely for these features.")
    if minor_features:
        recommendations.append(f"ℹ️ INFO: Minor drift detected in {len(minor_features)} feature(s). Continue monitoring.")
    if overall_drifted:
        recommendations.append("Update baseline distributions after addressing drift to reset monitoring.")
        recommendations.append("Investigate data pipeline changes that may have caused distribution shifts.")
        recommendations.append("Consider feature engineering adjustments for drifted features.")
    else:
        recommendations.append("✅ No significant drift detected. Model distributions are stable.")
        recommendations.append("Continue regular monitoring at current frequency.")
    return recommendations

@router.post("/baseline")
async def upload_baseline(
file: UploadFile = File(...),
name: Optional[str] = Form(None)
):
"""Upload baseline distribution for comparison"""
try:
conn, table_name = await read_to_duckdb(file)
numeric_cols = get_numeric_columns(conn, table_name)
if not numeric_cols:
raise HTTPException(status_code=400, detail="No numeric columns found in the dataset")
# Generate baseline ID
baseline_id = hashlib.md5(f"{file.filename}_{datetime.now().isoformat()}".encode()).hexdigest()[:12]
# Store baseline statistics and raw values for each column
baseline_data = {
"id": baseline_id,
"name": name or file.filename,
"filename": file.filename,
"created_at": datetime.now().isoformat(),
"row_count": conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0],
"columns": {},
"values": {}
}
for col in numeric_cols:
baseline_data["columns"][col] = get_column_stats(conn, table_name, col)
# Store actual values for PSI/KS calculation
values = conn.execute(f'SELECT "{col}"::DOUBLE FROM {table_name} WHERE "{col}" IS NOT NULL').fetchall()
baseline_data["values"][col] = np.array([v[0] for v in values])
baselines_store[baseline_id] = baseline_data
conn.close()
return {
"message": "Baseline uploaded successfully",
"baseline_id": baseline_id,
"name": baseline_data["name"],
"filename": file.filename,
"row_count": baseline_data["row_count"],
"numeric_columns": numeric_cols,
"column_stats": baseline_data["columns"],
"engine": "DuckDB"
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=f"Error processing baseline file: {str(e)}")
@router.get("/baselines")
async def list_baselines():
"""List all stored baselines"""
return {
"baselines": [
{
"id": b["id"],
"name": b["name"],
"filename": b["filename"],
"created_at": b["created_at"],
"row_count": b["row_count"],
"columns": list(b["columns"].keys())
}
for b in baselines_store.values()
]
}
@router.delete("/baseline/{baseline_id}")
async def delete_baseline(baseline_id: str):
"""Delete a stored baseline"""
if baseline_id not in baselines_store:
raise HTTPException(status_code=404, detail="Baseline not found")
del baselines_store[baseline_id]
return {"message": "Baseline deleted", "baseline_id": baseline_id}
@router.post("/analyze")
async def analyze_drift(
production_file: UploadFile = File(...),
baseline_id: str = Form(...)
):
"""Analyze production data for drift against baseline"""
if baseline_id not in baselines_store:
raise HTTPException(status_code=404, detail=f"Baseline '{baseline_id}' not found. Upload a baseline first.")
try:
baseline = baselines_store[baseline_id]
conn, table_name = await read_to_duckdb(production_file)
numeric_cols = get_numeric_columns(conn, table_name)
common_cols = [col for col in numeric_cols if col in baseline["columns"]]
if not common_cols:
raise HTTPException(status_code=400, detail="No matching numeric columns found between production data and baseline")
feature_scores = []
total_psi = 0.0
drifted_count = 0
for col in common_cols:
# Get current values
current_values = conn.execute(f'SELECT "{col}"::DOUBLE FROM {table_name} WHERE "{col}" IS NOT NULL').fetchall()
current_arr = np.array([v[0] for v in current_values])
baseline_arr = baseline["values"][col]
# Calculate drift metrics
psi = calculate_psi(baseline_arr, current_arr)
ks_stat, ks_pvalue = calculate_ks_statistic(baseline_arr, current_arr)
# Classify drift
is_drifted, drift_type = classify_drift(psi, ks_pvalue, current_thresholds.psi_threshold, current_thresholds.ks_threshold)
if is_drifted:
drifted_count += 1
total_psi += psi
feature_scores.append(FeatureDrift(
feature=col,
psi_score=round(psi, 4),
ks_statistic=round(ks_stat, 4),
ks_pvalue=round(ks_pvalue, 4),
is_drifted=is_drifted,
drift_type=drift_type,
baseline_stats=baseline["columns"][col],
current_stats=get_column_stats(conn, table_name, col)
))
conn.close()
# Calculate overall drift
avg_psi = total_psi / len(common_cols) if common_cols else 0
overall_drifted = drifted_count > 0
# Determine severity
severe_count = len([f for f in feature_scores if f.drift_type == "severe"])
moderate_count = len([f for f in feature_scores if f.drift_type == "moderate"])
if severe_count > 0:
drift_severity = "severe"
elif moderate_count > 0:
drift_severity = "moderate"
elif drifted_count > 0:
drift_severity = "minor"
else:
drift_severity = "none"
# Generate recommendations
recommendations = generate_recommendations(feature_scores, overall_drifted)
# Create result
result = DriftResult(
is_drifted=overall_drifted,
overall_score=round(avg_psi, 4),
drift_severity=drift_severity,
drifted_features=drifted_count,
total_features=len(common_cols),
feature_scores=feature_scores,
method="PSI + Kolmogorov-Smirnov",
recommendations=recommendations,
timestamp=datetime.now().isoformat(),
engine="DuckDB"
)
# Store in history
drift_history.append({
"baseline_id": baseline_id,
"production_file": production_file.filename,
"timestamp": result.timestamp,
"is_drifted": result.is_drifted,
"overall_score": result.overall_score,
"drift_severity": result.drift_severity,
"drifted_features": result.drifted_features,
"total_features": result.total_features
})
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=f"Error analyzing drift: {str(e)}")
@router.post("/compare-files")
async def compare_two_files(
baseline_file: UploadFile = File(...),
production_file: UploadFile = File(...)
):
"""Compare two files directly without storing baseline"""
try:
# Load both files
baseline_conn, baseline_table = await read_to_duckdb(baseline_file)
# Need to reset file position for second read
production_content = await production_file.read()
# Create production connection
prod_conn = duckdb.connect(":memory:")
filename = production_file.filename.lower() if production_file.filename else "file.csv"
suffix = '.csv' if filename.endswith('.csv') else '.json' if filename.endswith('.json') else '.csv'
with tempfile.NamedTemporaryFile(mode='wb', suffix=suffix, delete=False) as tmp:
tmp.write(production_content)
tmp_path = tmp.name
try:
if filename.endswith('.csv'):
prod_conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')")
elif filename.endswith('.json'):
prod_conn.execute(f"CREATE TABLE data AS SELECT * FROM read_json_auto('{tmp_path}')")
else:
prod_conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')")
finally:
os.unlink(tmp_path)
prod_table = "data"
# Get common numeric columns
baseline_cols = get_numeric_columns(baseline_conn, baseline_table)
prod_cols = get_numeric_columns(prod_conn, prod_table)
common_cols = list(set(baseline_cols) & set(prod_cols))
if not common_cols:
raise HTTPException(status_code=400, detail="No matching numeric columns found between files")
feature_scores = []
total_psi = 0.0
drifted_count = 0
for col in common_cols:
# Get values from both files
baseline_values = baseline_conn.execute(f'SELECT "{col}"::DOUBLE FROM {baseline_table} WHERE "{col}" IS NOT NULL').fetchall()
prod_values = prod_conn.execute(f'SELECT "{col}"::DOUBLE FROM {prod_table} WHERE "{col}" IS NOT NULL').fetchall()
baseline_arr = np.array([v[0] for v in baseline_values])
prod_arr = np.array([v[0] for v in prod_values])
# Calculate drift metrics
psi = calculate_psi(baseline_arr, prod_arr)
ks_stat, ks_pvalue = calculate_ks_statistic(baseline_arr, prod_arr)
# Classify drift
is_drifted, drift_type = classify_drift(psi, ks_pvalue, current_thresholds.psi_threshold, current_thresholds.ks_threshold)
if is_drifted:
drifted_count += 1
total_psi += psi
feature_scores.append(FeatureDrift(
feature=col,
psi_score=round(psi, 4),
ks_statistic=round(ks_stat, 4),
ks_pvalue=round(ks_pvalue, 4),
is_drifted=is_drifted,
drift_type=drift_type,
baseline_stats=get_column_stats(baseline_conn, baseline_table, col),
current_stats=get_column_stats(prod_conn, prod_table, col)
))
baseline_conn.close()
prod_conn.close()
# Calculate overall drift
avg_psi = total_psi / len(common_cols) if common_cols else 0
overall_drifted = drifted_count > 0
# Determine severity
severe_count = len([f for f in feature_scores if f.drift_type == "severe"])
moderate_count = len([f for f in feature_scores if f.drift_type == "moderate"])
if severe_count > 0:
drift_severity = "severe"
elif moderate_count > 0:
drift_severity = "moderate"
elif drifted_count > 0:
drift_severity = "minor"
else:
drift_severity = "none"
recommendations = generate_recommendations(feature_scores, overall_drifted)
return DriftResult(
is_drifted=overall_drifted,
overall_score=round(avg_psi, 4),
drift_severity=drift_severity,
drifted_features=drifted_count,
total_features=len(common_cols),
feature_scores=feature_scores,
method="PSI + Kolmogorov-Smirnov",
recommendations=recommendations,
timestamp=datetime.now().isoformat(),
engine="DuckDB"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=f"Error comparing files: {str(e)}")
@router.get("/history")
async def get_drift_history(limit: int = 100):
"""Get historical drift analysis results"""
return {
"history": drift_history[-limit:],
"total_analyses": len(drift_history)
}
@router.put("/thresholds")
async def update_thresholds(thresholds: DriftThresholds):
"""Update drift detection thresholds"""
global current_thresholds
current_thresholds = thresholds
return {
"message": "Thresholds updated",
"thresholds": {
"psi_threshold": current_thresholds.psi_threshold,
"ks_threshold": current_thresholds.ks_threshold,
"alert_enabled": current_thresholds.alert_enabled
}
}
@router.get("/thresholds")
async def get_thresholds():
"""Get current drift detection thresholds"""
return {
"psi_threshold": current_thresholds.psi_threshold,
"ks_threshold": current_thresholds.ks_threshold,
"alert_enabled": current_thresholds.alert_enabled,
"psi_interpretation": {
"low": "PSI < 0.1 - No significant change",
"moderate": "0.1 <= PSI < 0.2 - Moderate change, monitoring needed",
"high": "PSI >= 0.2 - Significant change, action required"
}
}
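# Illustrative threshold update (same host and router-prefix assumptions as above):
#
#   curl -X PUT http://localhost:8000/drift/thresholds \
#        -H "Content-Type: application/json" \
#        -d '{"psi_threshold": 0.25, "ks_threshold": 0.01, "alert_enabled": true}'
#
# Note: thresholds, baselines, and history live in process memory, so they reset on restart.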