""" PhotosynthesisPredictor: train and evaluate regression models on IMS features; report RMSE, MAE, R2. """ from __future__ import annotations from pathlib import Path from typing import Optional import numpy as np import pandas as pd from sklearn.linear_model import LinearRegression from sklearn.tree import DecisionTreeRegressor from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score try: from xgboost import XGBRegressor _HAS_XGB = True except ImportError: _HAS_XGB = False try: import matplotlib.pyplot as plt _HAS_PLOT = True except ImportError: _HAS_PLOT = False class PhotosynthesisPredictor: """Train multiple regressors and evaluate on test set.""" def __init__(self): self.models: dict = { "LinearRegression": LinearRegression(), "DecisionTree": DecisionTreeRegressor(max_depth=6, min_samples_leaf=10), "RandomForest": RandomForestRegressor( n_estimators=200, max_depth=8, min_samples_leaf=5, n_jobs=-1, random_state=42, ), "GradientBoosting": GradientBoostingRegressor( n_estimators=300, max_depth=4, learning_rate=0.05, min_samples_leaf=10, random_state=42, ), } if _HAS_XGB: self.models["XGBoost"] = XGBRegressor( n_estimators=300, max_depth=4, learning_rate=0.05, min_child_weight=10, reg_alpha=0.1, reg_lambda=1.0, n_jobs=-1, random_state=42, ) self.results: dict[str, dict] = {} def train(self, X_train: pd.DataFrame, y_train: pd.Series) -> None: """Fit all models on (X_train, y_train).""" for name, model in self.models.items(): model.fit(X_train, y_train) def evaluate( self, X_test: pd.DataFrame, y_test: pd.Series, ) -> pd.DataFrame: """ Predict with each model, compute RMSE, MAE, R2. Return comparison table. """ rows = [] for name, model in self.models.items(): pred = model.predict(X_test) rmse = float(np.sqrt(mean_squared_error(y_test, pred))) mae = float(mean_absolute_error(y_test, pred)) r2 = float(r2_score(y_test, pred)) self.results[name] = {"predictions": pred, "rmse": rmse, "mae": mae, "r2": r2} rows.append({"model": name, "RMSE": rmse, "MAE": mae, "R2": r2}) return pd.DataFrame(rows) def get_feature_importance(self, model_name: str | None = None) -> pd.DataFrame: """ Return feature importance from tree-based models. Prefers XGBoost > GradientBoosting > RandomForest > DecisionTree. """ if model_name: candidates = [model_name] else: candidates = ["XGBoost", "GradientBoosting", "RandomForest", "DecisionTree"] for name in candidates: m = self.models.get(name) if m is not None and hasattr(m, "feature_importances_"): imp = m.feature_importances_ return pd.DataFrame({ "feature": getattr(m, "feature_names_in_", list(range(len(imp)))), "importance": imp, }).sort_values("importance", ascending=False) return pd.DataFrame() def plot_results( self, y_test: pd.Series, predictions: Optional[dict[str, np.ndarray]] = None, save_path: Optional[Path] = None, ) -> None: """ Predicted vs approx A scatter and optional time series overlay. predictions: dict model_name -> pred array; if None use self.results. """ if not _HAS_PLOT: return preds = predictions or {n: self.results[n]["predictions"] for n in self.results} if not preds: return fig, axes = plt.subplots(1, 2, figsize=(12, 5)) # Scatter: pick best model by R2 best = max(self.results, key=lambda n: self.results[n].get("r2", -999)) if self.results else list(preds.keys())[0] name = best if best in preds else list(preds.keys())[0] ax = axes[0] ax.scatter(y_test, preds[name], alpha=0.5, s=10) mn = min(y_test.min(), preds[name].min()) mx = max(y_test.max(), preds[name].max()) ax.plot([mn, mx], [mn, mx], "k--", label="1:1") ax.set_xlabel("Approx A (µmol m⁻² s⁻¹)") ax.set_ylabel("Predicted A") ax.set_title(f"Predicted vs approx A ({name})") ax.legend() ax.set_aspect("equal") # Time series overlay — show top 2 models by R2 ax = axes[1] ax.plot(y_test.values, label="Approx A", alpha=0.8) ranked = sorted(self.results, key=lambda n: self.results[n].get("r2", -999), reverse=True) for n in ranked[:2]: if n in preds: ax.plot(preds[n], label=f"{n} (R²={self.results[n]['r2']:.2f})", alpha=0.7) ax.set_xlabel("Time index") ax.set_ylabel("A (umol m-2 s-1)") ax.set_title("Time series overlay") ax.legend() plt.tight_layout() if save_path: save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150) plt.close()