"""Model Drift Monitor Router - Detect distribution shifts in ML features"""
|
||
from fastapi import APIRouter, UploadFile, File, HTTPException, Form
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
import numpy as np
|
||
import duckdb
|
||
import tempfile
|
||
import os
|
||
import json
|
||
from datetime import datetime
|
||
import hashlib
|
||
|
||
router = APIRouter()
|
||
|
||
# In-memory storage for baselines and history
|
||
baselines_store: dict = {}
|
||
drift_history: list = []
|
||
|
||
|
||
class DriftThresholds(BaseModel):
    psi_threshold: float = 0.2  # PSI >= 0.2 indicates significant drift
    ks_threshold: float = 0.05  # KS p-value < 0.05 indicates drift
    alert_enabled: bool = True


class FeatureDrift(BaseModel):
    feature: str
    psi_score: float
    ks_statistic: float
    ks_pvalue: float
    is_drifted: bool
    drift_type: str  # "none", "minor", "moderate", "severe"
    baseline_stats: dict
    current_stats: dict


class DriftResult(BaseModel):
    is_drifted: bool
    overall_score: float
    drift_severity: str
    drifted_features: int
    total_features: int
    feature_scores: list[FeatureDrift]
    method: str
    recommendations: list[str]
    timestamp: str
    engine: str = "DuckDB"


# Current thresholds (in-memory, could be persisted)
current_thresholds = DriftThresholds()


async def read_to_duckdb(file: UploadFile) -> tuple[duckdb.DuckDBPyConnection, str]:
    """Read uploaded file into DuckDB in-memory database"""
    content = await file.read()
    filename = file.filename.lower() if file.filename else "file.csv"

    conn = duckdb.connect(":memory:")

    # Write to temp file for DuckDB to read
    suffix = '.csv' if filename.endswith('.csv') else '.json' if filename.endswith('.json') else '.csv'
    with tempfile.NamedTemporaryFile(mode='wb', suffix=suffix, delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        if filename.endswith('.csv'):
            conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')")
        elif filename.endswith('.json'):
            conn.execute(f"CREATE TABLE data AS SELECT * FROM read_json_auto('{tmp_path}')")
        else:
            conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')")
    finally:
        os.unlink(tmp_path)

    return conn, "data"


def get_numeric_columns(conn: duckdb.DuckDBPyConnection, table_name: str) -> list[str]:
    """Get list of numeric columns from table"""
    schema = conn.execute(f"DESCRIBE {table_name}").fetchall()
    numeric_types = ['INTEGER', 'BIGINT', 'DOUBLE', 'FLOAT', 'DECIMAL', 'REAL', 'SMALLINT', 'TINYINT', 'HUGEINT']
    return [col[0] for col in schema if any(t in col[1].upper() for t in numeric_types)]


def calculate_psi(baseline_values: np.ndarray, current_values: np.ndarray, bins: int = 10) -> float:
    """
    Calculate Population Stability Index (PSI)
    PSI < 0.1: No significant change
    0.1 <= PSI < 0.2: Moderate change, monitoring needed
    PSI >= 0.2: Significant change, action required
    """
    # Remove NaN values
    baseline_clean = baseline_values[~np.isnan(baseline_values)]
    current_clean = current_values[~np.isnan(current_values)]

    if len(baseline_clean) == 0 or len(current_clean) == 0:
        return 0.0

    # Create bins spanning the combined range of baseline and current values
    min_val = min(baseline_clean.min(), current_clean.min())
    max_val = max(baseline_clean.max(), current_clean.max())

    if min_val == max_val:
        return 0.0

    bin_edges = np.linspace(min_val, max_val, bins + 1)

    # Calculate counts for each bin
    baseline_counts, _ = np.histogram(baseline_clean, bins=bin_edges)
    current_counts, _ = np.histogram(current_clean, bins=bin_edges)

    # Convert to proportions (add small epsilon to avoid division by zero)
    epsilon = 1e-6
    baseline_prop = (baseline_counts + epsilon) / (len(baseline_clean) + epsilon * bins)
    current_prop = (current_counts + epsilon) / (len(current_clean) + epsilon * bins)

    # Calculate PSI: sum over bins of (current - baseline) * ln(current / baseline)
    psi = np.sum((current_prop - baseline_prop) * np.log(current_prop / baseline_prop))

    return float(psi)


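# Illustrative sketch (comment only, not executed; synthetic data is hypothetical):
# calculate_psi should stay near 0 when the two samples come from the same
# distribution and rise well above the 0.2 threshold for a clear mean shift.
#
#     rng = np.random.default_rng(0)
#     baseline = rng.normal(0.0, 1.0, 5_000)
#     same = rng.normal(0.0, 1.0, 5_000)       # PSI close to 0
#     shifted = rng.normal(1.0, 1.0, 5_000)    # PSI well above 0.2
#     calculate_psi(baseline, same), calculate_psi(baseline, shifted)

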
def calculate_ks_statistic(baseline_values: np.ndarray, current_values: np.ndarray) -> tuple[float, float]:
    """
    Calculate Kolmogorov-Smirnov statistic and approximate p-value
    """
    # Remove NaN values
    baseline_clean = baseline_values[~np.isnan(baseline_values)]
    current_clean = current_values[~np.isnan(current_values)]

    if len(baseline_clean) == 0 or len(current_clean) == 0:
        return 0.0, 1.0

    # Sort both arrays
    baseline_sorted = np.sort(baseline_clean)
    current_sorted = np.sort(current_clean)

    # Create combined array of all values
    all_values = np.concatenate([baseline_sorted, current_sorted])
    all_values = np.sort(np.unique(all_values))

    # Calculate empirical CDFs
    baseline_cdf = np.searchsorted(baseline_sorted, all_values, side='right') / len(baseline_sorted)
    current_cdf = np.searchsorted(current_sorted, all_values, side='right') / len(current_sorted)

    # KS statistic is the maximum difference between the two CDFs
    ks_stat = float(np.max(np.abs(baseline_cdf - current_cdf)))

    # Approximate p-value using asymptotic formula
    n1, n2 = len(baseline_clean), len(current_clean)
    en = np.sqrt(n1 * n2 / (n1 + n2))

    # Kolmogorov distribution approximation
    lambda_val = (en + 0.12 + 0.11 / en) * ks_stat

    # Two-sided p-value approximation
    if lambda_val < 0.001:
        p_value = 1.0
    else:
        # Series approximation: 2 * sum_j (-1)^(j-1) * exp(-2 * j^2 * lambda^2)
        j = np.arange(1, 101)
        p_value = 2 * np.sum((-1) ** (j - 1) * np.exp(-2 * j ** 2 * lambda_val ** 2))
        p_value = max(0.0, min(1.0, p_value))

    return ks_stat, float(p_value)


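# Optional cross-check sketch (assumes SciPy is installed; it is not a
# dependency of this module): scipy.stats.ks_2samp(baseline, current) computes
# the same two-sample KS statistic with its own p-value, and can be used in
# tests to validate calculate_ks_statistic against a reference implementation.

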
def get_column_stats(conn: duckdb.DuckDBPyConnection, table_name: str, column: str) -> dict:
    """Get statistics for a column using DuckDB"""
    try:
        stats = conn.execute(f'''
            SELECT
                COUNT(*) as count,
                COUNT("{column}") as non_null,
                AVG("{column}"::DOUBLE) as mean,
                STDDEV("{column}"::DOUBLE) as std,
                MIN("{column}"::DOUBLE) as min,
                MAX("{column}"::DOUBLE) as max,
                PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY "{column}") as q1,
                PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY "{column}") as median,
                PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{column}") as q3
            FROM {table_name}
        ''').fetchone()

        return {
            "count": stats[0],
            "non_null": stats[1],
            "mean": float(stats[2]) if stats[2] is not None else None,
            "std": float(stats[3]) if stats[3] is not None else None,
            "min": float(stats[4]) if stats[4] is not None else None,
            "max": float(stats[5]) if stats[5] is not None else None,
            "q1": float(stats[6]) if stats[6] is not None else None,
            "median": float(stats[7]) if stats[7] is not None else None,
            "q3": float(stats[8]) if stats[8] is not None else None
        }
    except Exception:
        return {"count": 0, "non_null": 0}


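# Shape of the returned dict, with hypothetical values for illustration:
#   {"count": 1000, "non_null": 998, "mean": 42.1, "std": 7.3, "min": 20.0,
#    "max": 65.0, "q1": 37.0, "median": 42.0, "q3": 47.5}
# On any query error the fallback {"count": 0, "non_null": 0} is returned.

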
def classify_drift(psi: float, ks_pvalue: float, psi_threshold: float, ks_threshold: float) -> tuple[bool, str]:
    """
    Classify drift severity from PSI and the KS test p-value.

    Severity bands are fixed (PSI 0.1 / 0.2 / 0.25, KS p-value 0.05 / 0.01);
    below those bands, the configurable thresholds decide the drifted flag.
    """
    is_drifted = psi >= psi_threshold or ks_pvalue < ks_threshold

    if psi >= 0.25 or ks_pvalue < 0.01:
        return True, "severe"
    elif psi >= 0.2 or ks_pvalue < 0.05:
        return True, "moderate"
    elif psi >= 0.1:
        return True, "minor"
    else:
        return is_drifted, "none"


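# Worked examples with the default thresholds (psi_threshold=0.2, ks_threshold=0.05),
# shown as comments only:
#   classify_drift(0.05, 0.50, 0.2, 0.05) -> (False, "none")
#   classify_drift(0.15, 0.50, 0.2, 0.05) -> (True, "minor")
#   classify_drift(0.05, 0.02, 0.2, 0.05) -> (True, "moderate")
#   classify_drift(0.30, 0.50, 0.2, 0.05) -> (True, "severe")

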
def generate_recommendations(feature_scores: list[FeatureDrift], overall_drifted: bool) -> list[str]:
    """Generate actionable recommendations based on drift analysis"""
    recommendations = []

    severe_features = [f.feature for f in feature_scores if f.drift_type == "severe"]
    moderate_features = [f.feature for f in feature_scores if f.drift_type == "moderate"]
    minor_features = [f.feature for f in feature_scores if f.drift_type == "minor"]

    if severe_features:
        recommendations.append(f"🚨 CRITICAL: Severe drift detected in {len(severe_features)} feature(s): {', '.join(severe_features[:5])}. Immediate model retraining recommended.")
        recommendations.append("Consider rolling back to a previous model version if performance degradation is observed.")

    if moderate_features:
        recommendations.append(f"⚠️ WARNING: Moderate drift in {len(moderate_features)} feature(s): {', '.join(moderate_features[:5])}. Schedule model retraining within 1-2 weeks.")
        recommendations.append("Monitor model performance metrics closely for these features.")

    if minor_features:
        recommendations.append(f"ℹ️ INFO: Minor drift detected in {len(minor_features)} feature(s). Continue monitoring.")

    if overall_drifted:
        recommendations.append("Update baseline distributions after addressing drift to reset monitoring.")
        recommendations.append("Investigate data pipeline changes that may have caused distribution shifts.")
        recommendations.append("Consider feature engineering adjustments for drifted features.")
    else:
        recommendations.append("✅ No significant drift detected. Model distributions are stable.")
        recommendations.append("Continue regular monitoring at current frequency.")

    return recommendations


@router.post("/baseline")
|
||
async def upload_baseline(
|
||
file: UploadFile = File(...),
|
||
name: Optional[str] = Form(None)
|
||
):
|
||
"""Upload baseline distribution for comparison"""
|
||
try:
|
||
conn, table_name = await read_to_duckdb(file)
|
||
numeric_cols = get_numeric_columns(conn, table_name)
|
||
|
||
if not numeric_cols:
|
||
raise HTTPException(status_code=400, detail="No numeric columns found in the dataset")
|
||
|
||
# Generate baseline ID
|
||
baseline_id = hashlib.md5(f"{file.filename}_{datetime.now().isoformat()}".encode()).hexdigest()[:12]
|
||
|
||
# Store baseline statistics and raw values for each column
|
||
baseline_data = {
|
||
"id": baseline_id,
|
||
"name": name or file.filename,
|
||
"filename": file.filename,
|
||
"created_at": datetime.now().isoformat(),
|
||
"row_count": conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0],
|
||
"columns": {},
|
||
"values": {}
|
||
}
|
||
|
||
for col in numeric_cols:
|
||
baseline_data["columns"][col] = get_column_stats(conn, table_name, col)
|
||
# Store actual values for PSI/KS calculation
|
||
values = conn.execute(f'SELECT "{col}"::DOUBLE FROM {table_name} WHERE "{col}" IS NOT NULL').fetchall()
|
||
baseline_data["values"][col] = np.array([v[0] for v in values])
|
||
|
||
baselines_store[baseline_id] = baseline_data
|
||
|
||
conn.close()
|
||
|
||
return {
|
||
"message": "Baseline uploaded successfully",
|
||
"baseline_id": baseline_id,
|
||
"name": baseline_data["name"],
|
||
"filename": file.filename,
|
||
"row_count": baseline_data["row_count"],
|
||
"numeric_columns": numeric_cols,
|
||
"column_stats": baseline_data["columns"],
|
||
"engine": "DuckDB"
|
||
}
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=400, detail=f"Error processing baseline file: {str(e)}")
|
||
|
||
|
||
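# Example request sketch (hypothetical host and mount prefix; the actual prefix
# is set by the application that includes this router):
#
#   curl -X POST "http://localhost:8000/drift/baseline" \
#        -F "file=@training_features.csv" \
#        -F "name=reference-training-set"

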
@router.get("/baselines")
|
||
async def list_baselines():
|
||
"""List all stored baselines"""
|
||
return {
|
||
"baselines": [
|
||
{
|
||
"id": b["id"],
|
||
"name": b["name"],
|
||
"filename": b["filename"],
|
||
"created_at": b["created_at"],
|
||
"row_count": b["row_count"],
|
||
"columns": list(b["columns"].keys())
|
||
}
|
||
for b in baselines_store.values()
|
||
]
|
||
}
|
||
|
||
|
||
@router.delete("/baseline/{baseline_id}")
|
||
async def delete_baseline(baseline_id: str):
|
||
"""Delete a stored baseline"""
|
||
if baseline_id not in baselines_store:
|
||
raise HTTPException(status_code=404, detail="Baseline not found")
|
||
|
||
del baselines_store[baseline_id]
|
||
return {"message": "Baseline deleted", "baseline_id": baseline_id}
|
||
|
||
|
||
@router.post("/analyze")
|
||
async def analyze_drift(
|
||
production_file: UploadFile = File(...),
|
||
baseline_id: str = Form(...)
|
||
):
|
||
"""Analyze production data for drift against baseline"""
|
||
if baseline_id not in baselines_store:
|
||
raise HTTPException(status_code=404, detail=f"Baseline '{baseline_id}' not found. Upload a baseline first.")
|
||
|
||
try:
|
||
baseline = baselines_store[baseline_id]
|
||
conn, table_name = await read_to_duckdb(production_file)
|
||
|
||
numeric_cols = get_numeric_columns(conn, table_name)
|
||
common_cols = [col for col in numeric_cols if col in baseline["columns"]]
|
||
|
||
if not common_cols:
|
||
raise HTTPException(status_code=400, detail="No matching numeric columns found between production data and baseline")
|
||
|
||
feature_scores = []
|
||
total_psi = 0.0
|
||
drifted_count = 0
|
||
|
||
for col in common_cols:
|
||
# Get current values
|
||
current_values = conn.execute(f'SELECT "{col}"::DOUBLE FROM {table_name} WHERE "{col}" IS NOT NULL').fetchall()
|
||
current_arr = np.array([v[0] for v in current_values])
|
||
baseline_arr = baseline["values"][col]
|
||
|
||
# Calculate drift metrics
|
||
psi = calculate_psi(baseline_arr, current_arr)
|
||
ks_stat, ks_pvalue = calculate_ks_statistic(baseline_arr, current_arr)
|
||
|
||
# Classify drift
|
||
is_drifted, drift_type = classify_drift(psi, ks_pvalue, current_thresholds.psi_threshold, current_thresholds.ks_threshold)
|
||
|
||
if is_drifted:
|
||
drifted_count += 1
|
||
|
||
total_psi += psi
|
||
|
||
feature_scores.append(FeatureDrift(
|
||
feature=col,
|
||
psi_score=round(psi, 4),
|
||
ks_statistic=round(ks_stat, 4),
|
||
ks_pvalue=round(ks_pvalue, 4),
|
||
is_drifted=is_drifted,
|
||
drift_type=drift_type,
|
||
baseline_stats=baseline["columns"][col],
|
||
current_stats=get_column_stats(conn, table_name, col)
|
||
))
|
||
|
||
conn.close()
|
||
|
||
# Calculate overall drift
|
||
avg_psi = total_psi / len(common_cols) if common_cols else 0
|
||
overall_drifted = drifted_count > 0
|
||
|
||
# Determine severity
|
||
severe_count = len([f for f in feature_scores if f.drift_type == "severe"])
|
||
moderate_count = len([f for f in feature_scores if f.drift_type == "moderate"])
|
||
|
||
if severe_count > 0:
|
||
drift_severity = "severe"
|
||
elif moderate_count > 0:
|
||
drift_severity = "moderate"
|
||
elif drifted_count > 0:
|
||
drift_severity = "minor"
|
||
else:
|
||
drift_severity = "none"
|
||
|
||
# Generate recommendations
|
||
recommendations = generate_recommendations(feature_scores, overall_drifted)
|
||
|
||
# Create result
|
||
result = DriftResult(
|
||
is_drifted=overall_drifted,
|
||
overall_score=round(avg_psi, 4),
|
||
drift_severity=drift_severity,
|
||
drifted_features=drifted_count,
|
||
total_features=len(common_cols),
|
||
feature_scores=feature_scores,
|
||
method="PSI + Kolmogorov-Smirnov",
|
||
recommendations=recommendations,
|
||
timestamp=datetime.now().isoformat(),
|
||
engine="DuckDB"
|
||
)
|
||
|
||
# Store in history
|
||
drift_history.append({
|
||
"baseline_id": baseline_id,
|
||
"production_file": production_file.filename,
|
||
"timestamp": result.timestamp,
|
||
"is_drifted": result.is_drifted,
|
||
"overall_score": result.overall_score,
|
||
"drift_severity": result.drift_severity,
|
||
"drifted_features": result.drifted_features,
|
||
"total_features": result.total_features
|
||
})
|
||
|
||
return result
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=400, detail=f"Error analyzing drift: {str(e)}")
|
||
|
||
|
||
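# Example request sketch (hypothetical host, prefix, and baseline_id):
#
#   curl -X POST "http://localhost:8000/drift/analyze" \
#        -F "production_file=@live_scoring_batch.csv" \
#        -F "baseline_id=1a2b3c4d5e6f"

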
@router.post("/compare-files")
|
||
async def compare_two_files(
|
||
baseline_file: UploadFile = File(...),
|
||
production_file: UploadFile = File(...)
|
||
):
|
||
"""Compare two files directly without storing baseline"""
|
||
try:
|
||
# Load both files
|
||
baseline_conn, baseline_table = await read_to_duckdb(baseline_file)
|
||
|
||
# Need to reset file position for second read
|
||
production_content = await production_file.read()
|
||
|
||
# Create production connection
|
||
prod_conn = duckdb.connect(":memory:")
|
||
filename = production_file.filename.lower() if production_file.filename else "file.csv"
|
||
suffix = '.csv' if filename.endswith('.csv') else '.json' if filename.endswith('.json') else '.csv'
|
||
|
||
with tempfile.NamedTemporaryFile(mode='wb', suffix=suffix, delete=False) as tmp:
|
||
tmp.write(production_content)
|
||
tmp_path = tmp.name
|
||
|
||
try:
|
||
if filename.endswith('.csv'):
|
||
prod_conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')")
|
||
elif filename.endswith('.json'):
|
||
prod_conn.execute(f"CREATE TABLE data AS SELECT * FROM read_json_auto('{tmp_path}')")
|
||
else:
|
||
prod_conn.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{tmp_path}')")
|
||
finally:
|
||
os.unlink(tmp_path)
|
||
|
||
prod_table = "data"
|
||
|
||
# Get common numeric columns
|
||
baseline_cols = get_numeric_columns(baseline_conn, baseline_table)
|
||
prod_cols = get_numeric_columns(prod_conn, prod_table)
|
||
common_cols = list(set(baseline_cols) & set(prod_cols))
|
||
|
||
if not common_cols:
|
||
raise HTTPException(status_code=400, detail="No matching numeric columns found between files")
|
||
|
||
feature_scores = []
|
||
total_psi = 0.0
|
||
drifted_count = 0
|
||
|
||
for col in common_cols:
|
||
# Get values from both files
|
||
baseline_values = baseline_conn.execute(f'SELECT "{col}"::DOUBLE FROM {baseline_table} WHERE "{col}" IS NOT NULL').fetchall()
|
||
prod_values = prod_conn.execute(f'SELECT "{col}"::DOUBLE FROM {prod_table} WHERE "{col}" IS NOT NULL').fetchall()
|
||
|
||
baseline_arr = np.array([v[0] for v in baseline_values])
|
||
prod_arr = np.array([v[0] for v in prod_values])
|
||
|
||
# Calculate drift metrics
|
||
psi = calculate_psi(baseline_arr, prod_arr)
|
||
ks_stat, ks_pvalue = calculate_ks_statistic(baseline_arr, prod_arr)
|
||
|
||
# Classify drift
|
||
is_drifted, drift_type = classify_drift(psi, ks_pvalue, current_thresholds.psi_threshold, current_thresholds.ks_threshold)
|
||
|
||
if is_drifted:
|
||
drifted_count += 1
|
||
|
||
total_psi += psi
|
||
|
||
feature_scores.append(FeatureDrift(
|
||
feature=col,
|
||
psi_score=round(psi, 4),
|
||
ks_statistic=round(ks_stat, 4),
|
||
ks_pvalue=round(ks_pvalue, 4),
|
||
is_drifted=is_drifted,
|
||
drift_type=drift_type,
|
||
baseline_stats=get_column_stats(baseline_conn, baseline_table, col),
|
||
current_stats=get_column_stats(prod_conn, prod_table, col)
|
||
))
|
||
|
||
baseline_conn.close()
|
||
prod_conn.close()
|
||
|
||
# Calculate overall drift
|
||
avg_psi = total_psi / len(common_cols) if common_cols else 0
|
||
overall_drifted = drifted_count > 0
|
||
|
||
# Determine severity
|
||
severe_count = len([f for f in feature_scores if f.drift_type == "severe"])
|
||
moderate_count = len([f for f in feature_scores if f.drift_type == "moderate"])
|
||
|
||
if severe_count > 0:
|
||
drift_severity = "severe"
|
||
elif moderate_count > 0:
|
||
drift_severity = "moderate"
|
||
elif drifted_count > 0:
|
||
drift_severity = "minor"
|
||
else:
|
||
drift_severity = "none"
|
||
|
||
recommendations = generate_recommendations(feature_scores, overall_drifted)
|
||
|
||
return DriftResult(
|
||
is_drifted=overall_drifted,
|
||
overall_score=round(avg_psi, 4),
|
||
drift_severity=drift_severity,
|
||
drifted_features=drifted_count,
|
||
total_features=len(common_cols),
|
||
feature_scores=feature_scores,
|
||
method="PSI + Kolmogorov-Smirnov",
|
||
recommendations=recommendations,
|
||
timestamp=datetime.now().isoformat(),
|
||
engine="DuckDB"
|
||
)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=400, detail=f"Error comparing files: {str(e)}")
|
||
|
||
|
||
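# Example request sketch (hypothetical host and prefix):
#
#   curl -X POST "http://localhost:8000/drift/compare-files" \
#        -F "baseline_file=@last_month.csv" \
#        -F "production_file=@this_month.csv"

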
@router.get("/history")
|
||
async def get_drift_history(limit: int = 100):
|
||
"""Get historical drift analysis results"""
|
||
return {
|
||
"history": drift_history[-limit:],
|
||
"total_analyses": len(drift_history)
|
||
}
|
||
|
||
|
||
@router.put("/thresholds")
|
||
async def update_thresholds(thresholds: DriftThresholds):
|
||
"""Update drift detection thresholds"""
|
||
global current_thresholds
|
||
current_thresholds = thresholds
|
||
return {
|
||
"message": "Thresholds updated",
|
||
"thresholds": {
|
||
"psi_threshold": current_thresholds.psi_threshold,
|
||
"ks_threshold": current_thresholds.ks_threshold,
|
||
"alert_enabled": current_thresholds.alert_enabled
|
||
}
|
||
}
|
||
|
||
|
||
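# Example request sketch (hypothetical host and prefix; the body is a
# DriftThresholds JSON object):
#
#   curl -X PUT "http://localhost:8000/drift/thresholds" \
#        -H "Content-Type: application/json" \
#        -d '{"psi_threshold": 0.25, "ks_threshold": 0.01, "alert_enabled": true}'

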
@router.get("/thresholds")
|
||
async def get_thresholds():
|
||
"""Get current drift detection thresholds"""
|
||
return {
|
||
"psi_threshold": current_thresholds.psi_threshold,
|
||
"ks_threshold": current_thresholds.ks_threshold,
|
||
"alert_enabled": current_thresholds.alert_enabled,
|
||
"psi_interpretation": {
|
||
"low": "PSI < 0.1 - No significant change",
|
||
"moderate": "0.1 <= PSI < 0.2 - Moderate change, monitoring needed",
|
||
"high": "PSI >= 0.2 - Significant change, action required"
|
||
}
|
||
}
|