# api/src/forecasting/predictor.py
# Eli Safra
# Deploy SolarWine API (FastAPI + Docker, port 7860)
# 938949f
"""
PhotosynthesisPredictor: train and evaluate regression models on IMS
features; report RMSE, MAE, R2.
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
# Optional dependencies: each import is attempted once at module load and the
# outcome is recorded in a flag, so the module stays importable without them.
try:
    from xgboost import XGBRegressor

    _HAS_XGB = True
except ImportError:
    # PhotosynthesisPredictor only registers its "XGBoost" model when this is True.
    _HAS_XGB = False

try:
    import matplotlib.pyplot as plt

    _HAS_PLOT = True
except ImportError:
    # PhotosynthesisPredictor.plot_results() returns immediately when this is False.
    _HAS_PLOT = False
class PhotosynthesisPredictor:
    """Train several regressors on one dataset and compare them on a test set.

    Attributes:
        models: Mapping of model name -> estimator instance. The "XGBoost"
            entry exists only when the optional ``xgboost`` import succeeded.
        results: Mapping of model name -> ``{"predictions", "rmse", "mae",
            "r2"}``; populated by :meth:`evaluate`.
    """

    # Order in which get_feature_importance() tries models when none is named.
    _IMPORTANCE_PREFERENCE = ("XGBoost", "GradientBoosting", "RandomForest", "DecisionTree")

    # Column layouts of the returned tables, shared by the empty and
    # non-empty paths so callers always see the same schema.
    _METRIC_COLUMNS = ("model", "RMSE", "MAE", "R2")
    _IMPORTANCE_COLUMNS = ("feature", "importance")

    def __init__(self):
        """Create the (unfitted) estimators and an empty results store."""
        self.models: dict = {
            "LinearRegression": LinearRegression(),
            "DecisionTree": DecisionTreeRegressor(max_depth=6, min_samples_leaf=10),
            "RandomForest": RandomForestRegressor(
                n_estimators=200, max_depth=8, min_samples_leaf=5,
                n_jobs=-1, random_state=42,
            ),
            "GradientBoosting": GradientBoostingRegressor(
                n_estimators=300, max_depth=4, learning_rate=0.05,
                min_samples_leaf=10, random_state=42,
            ),
        }
        if _HAS_XGB:
            self.models["XGBoost"] = XGBRegressor(
                n_estimators=300, max_depth=4, learning_rate=0.05,
                min_child_weight=10, reg_alpha=0.1, reg_lambda=1.0,
                n_jobs=-1, random_state=42,
            )
        self.results: dict[str, dict] = {}

    def train(self, X_train: pd.DataFrame, y_train: pd.Series) -> None:
        """Fit every model in ``self.models`` on (X_train, y_train)."""
        for model in self.models.values():
            model.fit(X_train, y_train)

    def evaluate(
        self,
        X_test: pd.DataFrame,
        y_test: pd.Series,
    ) -> pd.DataFrame:
        """Predict with each model and compute RMSE, MAE and R2.

        Predictions and metrics are also stored in ``self.results`` under the
        model name, replacing any earlier entry for that name.

        Returns:
            One row per model with columns ``model``, ``RMSE``, ``MAE``,
            ``R2``. The columns are present even when there are no models.
        """
        rows = []
        for name, model in self.models.items():
            pred = model.predict(X_test)
            rmse = float(np.sqrt(mean_squared_error(y_test, pred)))
            mae = float(mean_absolute_error(y_test, pred))
            r2 = float(r2_score(y_test, pred))
            self.results[name] = {"predictions": pred, "rmse": rmse, "mae": mae, "r2": r2}
            rows.append({"model": name, "RMSE": rmse, "MAE": mae, "R2": r2})
        return pd.DataFrame(rows, columns=list(self._METRIC_COLUMNS))

    def get_feature_importance(self, model_name: str | None = None) -> pd.DataFrame:
        """Return feature importances from a model that exposes them.

        Args:
            model_name: Model to query. When omitted (or empty), models are
                tried in the order XGBoost > GradientBoosting > RandomForest >
                DecisionTree and the first usable one wins.

        Returns:
            DataFrame with columns ``feature`` and ``importance`` sorted by
            descending importance. Features are taken from the model's
            ``feature_names_in_`` when present, otherwise positional indices.
            An empty frame with the same two columns is returned when no
            candidate model has a ``feature_importances_`` attribute.
        """
        candidates = [model_name] if model_name else list(self._IMPORTANCE_PREFERENCE)
        for name in candidates:
            model = self.models.get(name)
            # Skip names that are not registered and models without importances.
            if model is None or not hasattr(model, "feature_importances_"):
                continue
            importance = model.feature_importances_
            features = getattr(model, "feature_names_in_", list(range(len(importance))))
            table = pd.DataFrame({"feature": features, "importance": importance})
            return table.sort_values("importance", ascending=False)
        return pd.DataFrame(columns=list(self._IMPORTANCE_COLUMNS))

    def _rank_for_plot(self, preds: dict) -> list[str]:
        """Order the names in *preds*: scored models by descending R2, then the rest."""
        scored = [name for name in self.results if name in preds]
        # sort() is stable, so ties keep their self.results order.
        scored.sort(key=lambda name: self.results[name].get("r2", float("-inf")), reverse=True)
        unscored = [name for name in preds if name not in self.results]
        return scored + unscored

    def plot_results(
        self,
        y_test: pd.Series,
        predictions: Optional[dict[str, np.ndarray]] = None,
        save_path: Optional[Path] = None,
    ) -> None:
        """Draw a predicted-vs-approx-A scatter and a time series overlay.

        Does nothing when matplotlib is unavailable or there is nothing to
        plot. The figure is always closed before returning; it is written to
        disk only when ``save_path`` is given.

        Args:
            y_test: Reference values plotted on the x axis of the scatter and
                as the "Approx A" line of the overlay.
            predictions: Mapping of model name -> prediction array. When
                omitted or empty, the predictions stored by :meth:`evaluate`
                are used.
            save_path: Destination image file (``Path`` or string). Missing
                parent directories are created.
        """
        if not _HAS_PLOT:
            return
        preds = predictions or {name: res["predictions"] for name, res in self.results.items()}
        if not preds:
            return

        truth = np.asarray(y_test)
        # Models with a stored R2 come first (best first); models supplied
        # only through `predictions` follow in the order they were given.
        ranked = self._rank_for_plot(preds)
        best = ranked[0]
        best_pred = np.asarray(preds[best])

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        try:
            # Scatter: best available model against the reference values.
            ax = axes[0]
            ax.scatter(truth, best_pred, alpha=0.5, s=10)
            lo = min(truth.min(), best_pred.min())
            hi = max(truth.max(), best_pred.max())
            ax.plot([lo, hi], [lo, hi], "k--", label="1:1")
            ax.set_xlabel("Approx A (µmol m⁻² s⁻¹)")
            ax.set_ylabel("Predicted A")
            ax.set_title(f"Predicted vs approx A ({best})")
            ax.legend()
            ax.set_aspect("equal")

            # Time series overlay: reference plus the top two ranked models.
            ax = axes[1]
            ax.plot(truth, label="Approx A", alpha=0.8)
            for model_name in ranked[:2]:
                r2 = self.results.get(model_name, {}).get("r2")
                label = model_name if r2 is None else f"{model_name} (R²={r2:.2f})"
                ax.plot(np.asarray(preds[model_name]), label=label, alpha=0.7)
            ax.set_xlabel("Time index")
            ax.set_ylabel("A (umol m-2 s-1)")
            ax.set_title("Time series overlay")
            ax.legend()

            fig.tight_layout()
            if save_path is not None:
                target = Path(save_path)
                target.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(target, dpi=150)
        finally:
            # Close this specific figure even on error; pyplot keeps every
            # open figure alive until it is explicitly closed.
            plt.close(fig)