"""Model Drift Monitor Router - Detect distribution shifts in ML features"""
|
|||
|
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Form
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from typing import Optional
|
|||
|
|
import numpy as np
|
|||
|
|
import duckdb
|
|||
|
|
import tempfile
|
|||
|
|
import os
|
|||
|
|
import json
|
|||
|
|
from datetime import datetime
|
|||
|
|
import hashlib
|
|||
|
|
|
|||
|
|
router = APIRouter()
|
|||
|
|
|
|||
|
|
# In-memory storage for baselines and history
|
|||
|
|
baselines_store: dict = {}
|
|||
|
|
drift_history: list = []
|
|||
|
|
|
|||
|
|
|
|||
|
|
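# Note: baselines_store and drift_history (and current_thresholds below) are
# plain module-level objects, so all state is per-process: it is lost on
# restart and, assuming the app is served by multiple workers, not shared
# between them. A persistent store would be needed for production use.

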
class DriftThresholds(BaseModel):
    psi_threshold: float = 0.2  # PSI > 0.2 indicates significant drift
    ks_threshold: float = 0.05  # KS p-value < 0.05 indicates drift
    alert_enabled: bool = True


class FeatureDrift(BaseModel):
    feature: str
    psi_score: float
    ks_statistic: float
    ks_pvalue: float
    is_drifted: bool
    drift_type: str  # "none", "minor", "moderate", "severe"
    baseline_stats: dict
    current_stats: dict


class DriftResult(BaseModel):
    is_drifted: bool
    overall_score: float
    drift_severity: str
    drifted_features: int
    total_features: int
    feature_scores: list[FeatureDrift]
    method: str
    recommendations: list[str]
    timestamp: str
    engine: str = "DuckDB"


# Current thresholds (in-memory, could be persisted)
current_thresholds = DriftThresholds()


async def read_to_duckdb(file: UploadFile) -> tuple[duckdb.DuckDBPyConnection, str]:
    """Read uploaded file into DuckDB in-memory database"""
    content = await file.read()
    filename = file.filename.lower() if file.filename else "file.csv"

    conn = duckdb.connect(":memory:")

    # Write to temp file for DuckDB to read
    suffix = '.csv' if filename.endswith('.csv') else '.json' if filename.endswith('.json') else '.csv'
    with tempfile.NamedTemporaryFile(mode='wb', suffix=suffix, delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        if filename.endswith('.csv'):
            conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')")
        elif filename.endswith('.json'):
            conn.execute(f"CREATE TABLE data AS SELECT * FROM read_json_auto('{tmp_path}')")
        else:
            # Unknown extension: fall back to CSV parsing
            conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')")
    finally:
        os.unlink(tmp_path)

    return conn, "data"


def get_numeric_columns(conn: duckdb.DuckDBPyConnection, table_name: str) -> list[str]:
    """Get list of numeric columns from table"""
    schema = conn.execute(f"DESCRIBE {table_name}").fetchall()
    numeric_types = ['INTEGER', 'BIGINT', 'DOUBLE', 'FLOAT', 'DECIMAL', 'REAL', 'SMALLINT', 'TINYINT', 'HUGEINT']
    return [col[0] for col in schema if any(t in col[1].upper() for t in numeric_types)]


def calculate_psi(baseline_values: np.ndarray, current_values: np.ndarray, bins: int = 10) -> float:
    """
    Calculate Population Stability Index (PSI)

    PSI < 0.1: No significant change
    0.1 <= PSI < 0.2: Moderate change, monitoring needed
    PSI >= 0.2: Significant change, action required
    """
    # Remove NaN values
    baseline_clean = baseline_values[~np.isnan(baseline_values)]
    current_clean = current_values[~np.isnan(current_values)]

    if len(baseline_clean) == 0 or len(current_clean) == 0:
        return 0.0

    # Create equal-width bins spanning the combined range of both samples
    min_val = min(baseline_clean.min(), current_clean.min())
    max_val = max(baseline_clean.max(), current_clean.max())

    if min_val == max_val:
        return 0.0

    bin_edges = np.linspace(min_val, max_val, bins + 1)

    # Calculate counts for each bin
    baseline_counts, _ = np.histogram(baseline_clean, bins=bin_edges)
    current_counts, _ = np.histogram(current_clean, bins=bin_edges)

    # Convert to proportions (add small epsilon to avoid division by zero)
    epsilon = 1e-6
    baseline_prop = (baseline_counts + epsilon) / (len(baseline_clean) + epsilon * bins)
    current_prop = (current_counts + epsilon) / (len(current_clean) + epsilon * bins)

    # Calculate PSI
    psi = np.sum((current_prop - baseline_prop) * np.log(current_prop / baseline_prop))

    return float(psi)


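# Note on calculate_psi: with the epsilon smoothing above this is the standard
# binned PSI, PSI = sum_i (p_cur_i - p_base_i) * ln(p_cur_i / p_base_i).
# Rough illustration (exact values depend on the data): two samples drawn from
# the same distribution give a PSI near 0, while a mean shift of about one
# standard deviation typically pushes PSI well past the 0.2 "significant" band.

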
def calculate_ks_statistic(baseline_values: np.ndarray, current_values: np.ndarray) -> tuple[float, float]:
    """
    Calculate the two-sample Kolmogorov-Smirnov statistic and an approximate p-value
    """
    # Remove NaN values
    baseline_clean = baseline_values[~np.isnan(baseline_values)]
    current_clean = current_values[~np.isnan(current_values)]

    if len(baseline_clean) == 0 or len(current_clean) == 0:
        return 0.0, 1.0

    # Sort both arrays
    baseline_sorted = np.sort(baseline_clean)
    current_sorted = np.sort(current_clean)

    # Combined array of all observed values
    all_values = np.concatenate([baseline_sorted, current_sorted])
    all_values = np.sort(np.unique(all_values))

    # Evaluate both empirical CDFs at every observed value
    baseline_cdf = np.searchsorted(baseline_sorted, all_values, side='right') / len(baseline_sorted)
    current_cdf = np.searchsorted(current_sorted, all_values, side='right') / len(current_sorted)

    # KS statistic is the maximum difference between the two CDFs
    ks_stat = float(np.max(np.abs(baseline_cdf - current_cdf)))

    # Approximate two-sided p-value using the asymptotic Kolmogorov distribution
    n1, n2 = len(baseline_clean), len(current_clean)
    en = np.sqrt(n1 * n2 / (n1 + n2))

    # Effective-sample-size correction before evaluating the Kolmogorov series
    lambda_val = (en + 0.12 + 0.11 / en) * ks_stat

    if lambda_val < 0.001:
        p_value = 1.0
    else:
        # Truncated series: Q(lambda) = 2 * sum_{j>=1} (-1)^(j-1) * exp(-2 * j^2 * lambda^2)
        j = np.arange(1, 101)
        p_value = 2 * np.sum((-1) ** (j - 1) * np.exp(-2 * j ** 2 * lambda_val ** 2))
        p_value = max(0.0, min(1.0, p_value))

    return ks_stat, float(p_value)


def get_column_stats(conn: duckdb.DuckDBPyConnection, table_name: str, column: str) -> dict:
    """Get statistics for a column using DuckDB"""
    try:
        stats = conn.execute(f'''
            SELECT
                COUNT(*) as count,
                COUNT("{column}") as non_null,
                AVG("{column}"::DOUBLE) as mean,
                STDDEV("{column}"::DOUBLE) as std,
                MIN("{column}"::DOUBLE) as min,
                MAX("{column}"::DOUBLE) as max,
                PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY "{column}") as q1,
                PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY "{column}") as median,
                PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{column}") as q3
            FROM {table_name}
        ''').fetchone()

        return {
            "count": stats[0],
            "non_null": stats[1],
            "mean": float(stats[2]) if stats[2] is not None else None,
            "std": float(stats[3]) if stats[3] is not None else None,
            "min": float(stats[4]) if stats[4] is not None else None,
            "max": float(stats[5]) if stats[5] is not None else None,
            "q1": float(stats[6]) if stats[6] is not None else None,
            "median": float(stats[7]) if stats[7] is not None else None,
            "q3": float(stats[8]) if stats[8] is not None else None
        }
    except Exception:
        return {"count": 0, "non_null": 0}


def classify_drift(psi: float, ks_pvalue: float, psi_threshold: float, ks_threshold: float) -> tuple[bool, str]:
    """Classify drift severity based on PSI and KS test"""
    is_drifted = psi >= psi_threshold or ks_pvalue < ks_threshold

    if psi >= 0.25 or ks_pvalue < 0.01:
        return True, "severe"
    elif psi >= 0.2 or ks_pvalue < 0.05:
        return True, "moderate"
    elif psi >= 0.1:
        return True, "minor"
    else:
        return is_drifted, "none"


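# Note on classify_drift: the severity bands above (PSI >= 0.25 or p < 0.01 ->
# "severe"; PSI >= 0.2 or p < 0.05 -> "moderate"; PSI >= 0.1 -> "minor") are
# hard-coded and apply regardless of the configurable psi_threshold and
# ks_threshold, which only feed the fallback is_drifted flag in the final branch.

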
def generate_recommendations(feature_scores: list[FeatureDrift], overall_drifted: bool) -> list[str]:
    """Generate actionable recommendations based on drift analysis"""
    recommendations = []

    severe_features = [f.feature for f in feature_scores if f.drift_type == "severe"]
    moderate_features = [f.feature for f in feature_scores if f.drift_type == "moderate"]
    minor_features = [f.feature for f in feature_scores if f.drift_type == "minor"]

    if severe_features:
        recommendations.append(f"🚨 CRITICAL: Severe drift detected in {len(severe_features)} feature(s): {', '.join(severe_features[:5])}. Immediate model retraining recommended.")
        recommendations.append("Consider rolling back to a previous model version if performance degradation is observed.")

    if moderate_features:
        recommendations.append(f"⚠️ WARNING: Moderate drift in {len(moderate_features)} feature(s): {', '.join(moderate_features[:5])}. Schedule model retraining within 1-2 weeks.")
        recommendations.append("Monitor model performance metrics closely for these features.")

    if minor_features:
        recommendations.append(f"ℹ️ INFO: Minor drift detected in {len(minor_features)} feature(s). Continue monitoring.")

    if overall_drifted:
        recommendations.append("Update baseline distributions after addressing drift to reset monitoring.")
        recommendations.append("Investigate data pipeline changes that may have caused distribution shifts.")
        recommendations.append("Consider feature engineering adjustments for drifted features.")
    else:
        recommendations.append("✅ No significant drift detected. Model distributions are stable.")
        recommendations.append("Continue regular monitoring at current frequency.")

    return recommendations


@router.post("/baseline")
async def upload_baseline(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None)
):
    """Upload baseline distribution for comparison"""
    try:
        conn, table_name = await read_to_duckdb(file)
        numeric_cols = get_numeric_columns(conn, table_name)

        if not numeric_cols:
            raise HTTPException(status_code=400, detail="No numeric columns found in the dataset")

        # Generate baseline ID
        baseline_id = hashlib.md5(f"{file.filename}_{datetime.now().isoformat()}".encode()).hexdigest()[:12]

        # Store baseline statistics and raw values for each column
        baseline_data = {
            "id": baseline_id,
            "name": name or file.filename,
            "filename": file.filename,
            "created_at": datetime.now().isoformat(),
            "row_count": conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0],
            "columns": {},
            "values": {}
        }

        for col in numeric_cols:
            baseline_data["columns"][col] = get_column_stats(conn, table_name, col)
            # Store actual values for PSI/KS calculation
            values = conn.execute(f'SELECT "{col}"::DOUBLE FROM {table_name} WHERE "{col}" IS NOT NULL').fetchall()
            baseline_data["values"][col] = np.array([v[0] for v in values])

        baselines_store[baseline_id] = baseline_data

        conn.close()

        return {
            "message": "Baseline uploaded successfully",
            "baseline_id": baseline_id,
            "name": baseline_data["name"],
            "filename": file.filename,
            "row_count": baseline_data["row_count"],
            "numeric_columns": numeric_cols,
            "column_stats": baseline_data["columns"],
            "engine": "DuckDB"
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing baseline file: {str(e)}")


@router.get("/baselines")
async def list_baselines():
    """List all stored baselines"""
    return {
        "baselines": [
            {
                "id": b["id"],
                "name": b["name"],
                "filename": b["filename"],
                "created_at": b["created_at"],
                "row_count": b["row_count"],
                "columns": list(b["columns"].keys())
            }
            for b in baselines_store.values()
        ]
    }


@router.delete("/baseline/{baseline_id}")
async def delete_baseline(baseline_id: str):
    """Delete a stored baseline"""
    if baseline_id not in baselines_store:
        raise HTTPException(status_code=404, detail="Baseline not found")

    del baselines_store[baseline_id]
    return {"message": "Baseline deleted", "baseline_id": baseline_id}


@router.post("/analyze")
async def analyze_drift(
    production_file: UploadFile = File(...),
    baseline_id: str = Form(...)
):
    """Analyze production data for drift against baseline"""
    if baseline_id not in baselines_store:
        raise HTTPException(status_code=404, detail=f"Baseline '{baseline_id}' not found. Upload a baseline first.")

    try:
        baseline = baselines_store[baseline_id]
        conn, table_name = await read_to_duckdb(production_file)

        numeric_cols = get_numeric_columns(conn, table_name)
        common_cols = [col for col in numeric_cols if col in baseline["columns"]]

        if not common_cols:
            raise HTTPException(status_code=400, detail="No matching numeric columns found between production data and baseline")

        feature_scores = []
        total_psi = 0.0
        drifted_count = 0

        for col in common_cols:
            # Get current values
            current_values = conn.execute(f'SELECT "{col}"::DOUBLE FROM {table_name} WHERE "{col}" IS NOT NULL').fetchall()
            current_arr = np.array([v[0] for v in current_values])
            baseline_arr = baseline["values"][col]

            # Calculate drift metrics
            psi = calculate_psi(baseline_arr, current_arr)
            ks_stat, ks_pvalue = calculate_ks_statistic(baseline_arr, current_arr)

            # Classify drift
            is_drifted, drift_type = classify_drift(psi, ks_pvalue, current_thresholds.psi_threshold, current_thresholds.ks_threshold)

            if is_drifted:
                drifted_count += 1

            total_psi += psi

            feature_scores.append(FeatureDrift(
                feature=col,
                psi_score=round(psi, 4),
                ks_statistic=round(ks_stat, 4),
                ks_pvalue=round(ks_pvalue, 4),
                is_drifted=is_drifted,
                drift_type=drift_type,
                baseline_stats=baseline["columns"][col],
                current_stats=get_column_stats(conn, table_name, col)
            ))

        conn.close()

        # Calculate overall drift
        avg_psi = total_psi / len(common_cols) if common_cols else 0
        overall_drifted = drifted_count > 0

        # Determine severity
        severe_count = len([f for f in feature_scores if f.drift_type == "severe"])
        moderate_count = len([f for f in feature_scores if f.drift_type == "moderate"])

        if severe_count > 0:
            drift_severity = "severe"
        elif moderate_count > 0:
            drift_severity = "moderate"
        elif drifted_count > 0:
            drift_severity = "minor"
        else:
            drift_severity = "none"

        # Generate recommendations
        recommendations = generate_recommendations(feature_scores, overall_drifted)

        # Create result
        result = DriftResult(
            is_drifted=overall_drifted,
            overall_score=round(avg_psi, 4),
            drift_severity=drift_severity,
            drifted_features=drifted_count,
            total_features=len(common_cols),
            feature_scores=feature_scores,
            method="PSI + Kolmogorov-Smirnov",
            recommendations=recommendations,
            timestamp=datetime.now().isoformat(),
            engine="DuckDB"
        )

        # Store in history
        drift_history.append({
            "baseline_id": baseline_id,
            "production_file": production_file.filename,
            "timestamp": result.timestamp,
            "is_drifted": result.is_drifted,
            "overall_score": result.overall_score,
            "drift_severity": result.drift_severity,
            "drifted_features": result.drifted_features,
            "total_features": result.total_features
        })

        return result

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error analyzing drift: {str(e)}")


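# Illustrative client calls for the two endpoints above (hypothetical: the URL
# prefix depends on how this router is mounted in the application, which is not
# shown in this module):
#
#   import requests
#   resp = requests.post("http://localhost:8000/drift/baseline",
#                        files={"file": open("train_features.csv", "rb")},
#                        data={"name": "training-2024"})
#   baseline_id = resp.json()["baseline_id"]
#   requests.post("http://localhost:8000/drift/analyze",
#                 files={"production_file": open("prod_features.csv", "rb")},
#                 data={"baseline_id": baseline_id})

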
@router.post("/compare-files")
async def compare_two_files(
    baseline_file: UploadFile = File(...),
    production_file: UploadFile = File(...)
):
    """Compare two files directly without storing a baseline"""
    try:
        # Load both files into separate in-memory DuckDB databases
        baseline_conn, baseline_table = await read_to_duckdb(baseline_file)
        prod_conn, prod_table = await read_to_duckdb(production_file)

        # Get common numeric columns
        baseline_cols = get_numeric_columns(baseline_conn, baseline_table)
        prod_cols = get_numeric_columns(prod_conn, prod_table)
        common_cols = list(set(baseline_cols) & set(prod_cols))

        if not common_cols:
            raise HTTPException(status_code=400, detail="No matching numeric columns found between files")

        feature_scores = []
        total_psi = 0.0
        drifted_count = 0

        for col in common_cols:
            # Get values from both files
            baseline_values = baseline_conn.execute(f'SELECT "{col}"::DOUBLE FROM {baseline_table} WHERE "{col}" IS NOT NULL').fetchall()
            prod_values = prod_conn.execute(f'SELECT "{col}"::DOUBLE FROM {prod_table} WHERE "{col}" IS NOT NULL').fetchall()

            baseline_arr = np.array([v[0] for v in baseline_values])
            prod_arr = np.array([v[0] for v in prod_values])

            # Calculate drift metrics
            psi = calculate_psi(baseline_arr, prod_arr)
            ks_stat, ks_pvalue = calculate_ks_statistic(baseline_arr, prod_arr)

            # Classify drift
            is_drifted, drift_type = classify_drift(psi, ks_pvalue, current_thresholds.psi_threshold, current_thresholds.ks_threshold)

            if is_drifted:
                drifted_count += 1

            total_psi += psi

            feature_scores.append(FeatureDrift(
                feature=col,
                psi_score=round(psi, 4),
                ks_statistic=round(ks_stat, 4),
                ks_pvalue=round(ks_pvalue, 4),
                is_drifted=is_drifted,
                drift_type=drift_type,
                baseline_stats=get_column_stats(baseline_conn, baseline_table, col),
                current_stats=get_column_stats(prod_conn, prod_table, col)
            ))

        baseline_conn.close()
        prod_conn.close()

        # Calculate overall drift
        avg_psi = total_psi / len(common_cols) if common_cols else 0
        overall_drifted = drifted_count > 0

        # Determine severity
        severe_count = len([f for f in feature_scores if f.drift_type == "severe"])
        moderate_count = len([f for f in feature_scores if f.drift_type == "moderate"])

        if severe_count > 0:
            drift_severity = "severe"
        elif moderate_count > 0:
            drift_severity = "moderate"
        elif drifted_count > 0:
            drift_severity = "minor"
        else:
            drift_severity = "none"

        recommendations = generate_recommendations(feature_scores, overall_drifted)

        return DriftResult(
            is_drifted=overall_drifted,
            overall_score=round(avg_psi, 4),
            drift_severity=drift_severity,
            drifted_features=drifted_count,
            total_features=len(common_cols),
            feature_scores=feature_scores,
            method="PSI + Kolmogorov-Smirnov",
            recommendations=recommendations,
            timestamp=datetime.now().isoformat(),
            engine="DuckDB"
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error comparing files: {str(e)}")


@router.get("/history")
async def get_drift_history(limit: int = 100):
    """Get historical drift analysis results"""
    return {
        "history": drift_history[-limit:],
        "total_analyses": len(drift_history)
    }


@router.put("/thresholds")
async def update_thresholds(thresholds: DriftThresholds):
    """Update drift detection thresholds"""
    global current_thresholds
    current_thresholds = thresholds
    return {
        "message": "Thresholds updated",
        "thresholds": {
            "psi_threshold": current_thresholds.psi_threshold,
            "ks_threshold": current_thresholds.ks_threshold,
            "alert_enabled": current_thresholds.alert_enabled
        }
    }


@router.get("/thresholds")
async def get_thresholds():
    """Get current drift detection thresholds"""
    return {
        "psi_threshold": current_thresholds.psi_threshold,
        "ks_threshold": current_thresholds.ks_threshold,
        "alert_enabled": current_thresholds.alert_enabled,
        "psi_interpretation": {
            "low": "PSI < 0.1 - No significant change",
            "moderate": "0.1 <= PSI < 0.2 - Moderate change, monitoring needed",
            "high": "PSI >= 0.2 - Significant change, action required"
        }
    }
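

# Minimal smoke test for the statistics helpers (a sketch on synthetic data;
# it only runs when this module is executed directly, not when the router is
# imported by the application).
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    base = rng.normal(loc=0.0, scale=1.0, size=5000)
    same = rng.normal(loc=0.0, scale=1.0, size=5000)
    shifted = rng.normal(loc=1.0, scale=1.0, size=5000)

    # Identical distributions: PSI should be near 0 and the KS p-value large
    print("no drift -> PSI=%.4f, KS p=%.4f" % (calculate_psi(base, same), calculate_ks_statistic(base, same)[1]))
    # Mean shifted by one standard deviation: PSI well above 0.2, KS p-value near 0
    print("drifted  -> PSI=%.4f, KS p=%.4f" % (calculate_psi(base, shifted), calculate_ks_statistic(base, shifted)[1]))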