| """ |
| PhotosynthesisPredictor: train and evaluate regression models on IMS |
| features; report RMSE, MAE, R2. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
| import pandas as pd |
| from sklearn.linear_model import LinearRegression |
| from sklearn.tree import DecisionTreeRegressor |
| from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor |
| from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score |
|
|
| try: |
| from xgboost import XGBRegressor |
| _HAS_XGB = True |
| except ImportError: |
| _HAS_XGB = False |
|
|
| try: |
| import matplotlib.pyplot as plt |
| _HAS_PLOT = True |
| except ImportError: |
| _HAS_PLOT = False |
|
|
|
|
class PhotosynthesisPredictor:
    """Train multiple regressors on IMS features and evaluate on a test set.

    The model zoo is fixed at construction time; XGBoost is added only when
    the optional ``xgboost`` package is importable (``_HAS_XGB``), and
    plotting is a no-op when matplotlib is unavailable (``_HAS_PLOT``).

    Attributes:
        models: Mapping of model name -> scikit-learn-style regressor.
        results: Per-model evaluation artifacts populated by :meth:`evaluate`
            (keys: ``predictions``, ``rmse``, ``mae``, ``r2``).
    """

    def __init__(self) -> None:
        self.models: dict = {
            "LinearRegression": LinearRegression(),
            "DecisionTree": DecisionTreeRegressor(max_depth=6, min_samples_leaf=10),
            "RandomForest": RandomForestRegressor(
                n_estimators=200, max_depth=8, min_samples_leaf=5,
                n_jobs=-1, random_state=42,
            ),
            "GradientBoosting": GradientBoostingRegressor(
                n_estimators=300, max_depth=4, learning_rate=0.05,
                min_samples_leaf=10, random_state=42,
            ),
        }
        if _HAS_XGB:
            # Hyperparameters mirror GradientBoosting where the APIs overlap;
            # L1/L2 regularization added on top.
            self.models["XGBoost"] = XGBRegressor(
                n_estimators=300, max_depth=4, learning_rate=0.05,
                min_child_weight=10, reg_alpha=0.1, reg_lambda=1.0,
                n_jobs=-1, random_state=42,
            )
        self.results: dict[str, dict] = {}

    def train(self, X_train: pd.DataFrame, y_train: pd.Series) -> None:
        """Fit every configured model on (X_train, y_train)."""
        for model in self.models.values():
            model.fit(X_train, y_train)

    def evaluate(
        self,
        X_test: pd.DataFrame,
        y_test: pd.Series,
    ) -> pd.DataFrame:
        """Predict with each model and compute RMSE, MAE and R².

        Side effect: caches per-model predictions and metrics in
        ``self.results`` for later use by :meth:`plot_results`.

        Returns:
            One row per model with columns ``model``, ``RMSE``, ``MAE``, ``R2``.
        """
        rows = []
        for name, model in self.models.items():
            pred = model.predict(X_test)
            # np.sqrt(MSE) keeps compatibility across sklearn versions
            # (the `squared=` keyword is deprecated in newer releases).
            rmse = float(np.sqrt(mean_squared_error(y_test, pred)))
            mae = float(mean_absolute_error(y_test, pred))
            r2 = float(r2_score(y_test, pred))
            self.results[name] = {"predictions": pred, "rmse": rmse, "mae": mae, "r2": r2}
            rows.append({"model": name, "RMSE": rmse, "MAE": mae, "R2": r2})
        return pd.DataFrame(rows)

    def get_feature_importance(self, model_name: str | None = None) -> pd.DataFrame:
        """Return feature importances from a tree-based model.

        Args:
            model_name: Specific model to query. When None, the first model
                exposing importances in the priority order
                XGBoost > GradientBoosting > RandomForest > DecisionTree
                is used.

        Returns:
            DataFrame with ``feature`` and ``importance`` columns sorted by
            descending importance, or an empty DataFrame when no candidate
            exposes ``feature_importances_``.
        """
        candidates = (
            [model_name]
            if model_name
            else ["XGBoost", "GradientBoosting", "RandomForest", "DecisionTree"]
        )
        for name in candidates:
            model = self.models.get(name)
            if model is None or not hasattr(model, "feature_importances_"):
                continue
            importances = model.feature_importances_
            # Fall back to positional indices when the model was fitted on a
            # plain array and therefore lacks feature_names_in_.
            names = getattr(model, "feature_names_in_", list(range(len(importances))))
            return (
                pd.DataFrame({"feature": names, "importance": importances})
                .sort_values("importance", ascending=False)
            )
        return pd.DataFrame()

    def plot_results(
        self,
        y_test: pd.Series,
        predictions: dict[str, np.ndarray] | None = None,
        save_path: Path | None = None,
    ) -> None:
        """Plot predicted-vs-approx-A scatter and a time-series overlay.

        Args:
            y_test: Ground-truth (approximate) A values.
            predictions: Mapping model name -> prediction array. When None,
                predictions cached by :meth:`evaluate` are used; an explicitly
                passed empty dict plots nothing.
            save_path: If given, the figure is saved there (parents created).

        No-op when matplotlib is unavailable or there is nothing to plot.
        """
        if not _HAS_PLOT:
            return
        # Fix: identity check instead of truthiness, so an explicitly passed
        # empty dict is honored rather than silently falling back to cached
        # results (matches the documented "if None use self.results").
        if predictions is None:
            preds = {n: self.results[n]["predictions"] for n in self.results}
        else:
            preds = predictions
        if not preds:
            return
        fig, (ax_scatter, ax_series) = plt.subplots(1, 2, figsize=(12, 5))
        self._plot_scatter(ax_scatter, y_test, preds)
        self._plot_timeseries(ax_series, y_test, preds)
        plt.tight_layout()
        if save_path:
            save_path.parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(save_path, dpi=150)
        # Fix: close this specific figure instead of whatever figure happens
        # to be current when the caller has other figures open.
        plt.close(fig)

    def _ranked_models(self) -> list:
        """Model names from self.results, best R² first (stable for ties)."""
        return sorted(
            self.results, key=lambda n: self.results[n].get("r2", -999), reverse=True
        )

    def _plot_scatter(self, ax, y_test: pd.Series, preds: dict) -> None:
        """Scatter of predictions vs approx A for the best-R² model in preds."""
        ranked = self._ranked_models()
        best = ranked[0] if ranked else next(iter(preds))
        name = best if best in preds else next(iter(preds))
        ax.scatter(y_test, preds[name], alpha=0.5, s=10)
        lo = min(y_test.min(), preds[name].min())
        hi = max(y_test.max(), preds[name].max())
        ax.plot([lo, hi], [lo, hi], "k--", label="1:1")
        ax.set_xlabel("Approx A (µmol m⁻² s⁻¹)")
        ax.set_ylabel("Predicted A")
        ax.set_title(f"Predicted vs approx A ({name})")
        ax.legend()
        ax.set_aspect("equal")

    def _plot_timeseries(self, ax, y_test: pd.Series, preds: dict) -> None:
        """Overlay approx A with the two best models' prediction series."""
        ax.plot(y_test.values, label="Approx A", alpha=0.8)
        for name in self._ranked_models()[:2]:
            if name in preds:
                ax.plot(
                    preds[name],
                    label=f"{name} (R²={self.results[name]['r2']:.2f})",
                    alpha=0.7,
                )
        ax.set_xlabel("Time index")
        ax.set_ylabel("A (umol m-2 s-1)")
        ax.set_title("Time series overlay")
        ax.legend()
|
|