# bosh94's picture
# Update app.py
# de1701c verified
import os
import time
import inspect
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import gradio as gr
import torch
import plotly.graph_objects as go
from chronos import Chronos2Pipeline
# Model identifier used when the caller supplies none; can be overridden
# through the CHRONOS_MODEL_ID environment variable.
MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
# Directory scanned for the bundled test CSV files.
DATA_DIR = "data"
# Directory where the forecast CSV export is written.
OUT_DIR = "/tmp"
# -------------------------
# Data
# -------------------------
def available_test_csv() -> List[str]:
    """Return the sorted names of the ``.csv`` entries found in ``DATA_DIR``.

    The extension check is case-insensitive. An empty list is returned when
    ``DATA_DIR`` is not an existing directory.
    """
    if os.path.isdir(DATA_DIR):
        csv_names = [entry for entry in os.listdir(DATA_DIR) if entry.lower().endswith(".csv")]
        csv_names.sort()
        return csv_names
    return []
def pick_device(ui_choice: str) -> str:
    """Map the device label chosen in the UI to ``"cuda"`` or ``"cpu"``.

    ``"cuda"`` is returned only when the label starts with ``"cuda"`` and
    torch reports a usable CUDA device; every other case (including an empty
    or missing label) yields ``"cpu"``.
    """
    wants_cuda = (ui_choice or "").startswith("cuda")
    if wants_cuda and torch.cuda.is_available():
        return "cuda"
    return "cpu"
def make_sample_series(n: int, seed: int, trend: float, season_period: int, season_amp: float, noise: float) -> np.ndarray:
    """Generate a synthetic float32 series: linear trend + sine seasonality + Gaussian noise.

    Args:
        n: Number of points (cast to int).
        seed: Seed for the NumPy random generator (cast to int).
        trend: Slope applied to the time index.
        season_period: Period of the sine component; values below 1 are treated as 1.
        season_amp: Amplitude of the sine component.
        noise: Standard deviation of the additive Gaussian noise.

    Returns:
        A 1-D float32 array of length ``n``. If the raw series dips below
        zero it is shifted up so that its minimum becomes 0.
    """
    length = int(n)
    period = max(1, int(season_period))
    rng = np.random.default_rng(int(seed))

    steps = np.arange(length, dtype=np.float32)
    trend_part = trend * steps
    seasonal_part = season_amp * np.sin(2 * np.pi * steps / period)
    noise_part = rng.normal(0, noise, size=length)
    series = (trend_part + seasonal_part + noise_part).astype(np.float32)

    # Shift upward only when needed so the series is never negative.
    lowest = float(np.min(series))
    if lowest < 0:
        series -= lowest
    return series
def load_series_from_csv(csv_path: str, column: Optional[str]) -> Tuple[np.ndarray, str]:
    """Load one numeric column of a CSV file as a float32 series.

    Args:
        csv_path: Path (or anything ``pandas.read_csv`` accepts) of the CSV.
        column: Name of the column to use. When empty or ``None`` the column
            is auto-detected: the first natively numeric column that holds at
            least one value, otherwise the first column with at least one
            value that can be coerced to a number.

    Returns:
        ``(values, column_name)`` where ``values`` is a 1-D float32 array with
        non-numeric / missing entries dropped.

    Raises:
        ValueError: No usable numeric column exists, the requested column is
            missing, or fewer than 10 numeric values remain.
    """
    df = pd.read_csv(csv_path)
    col = (column or "").strip()

    if not col:
        # A completely blank column is parsed as all-NaN float64 and therefore
        # looks "numeric"; require at least one real value so such a column
        # does not shadow a usable one.
        candidates = [
            c for c in df.columns
            if pd.api.types.is_numeric_dtype(df[c]) and df[c].notna().any()
        ]
        if not candidates:
            # Fall back to text columns whose content can be parsed as numbers.
            candidates = [
                c for c in df.columns
                if pd.to_numeric(df[c], errors="coerce").notna().any()
            ]
        if not candidates:
            raise ValueError("Non trovo colonne numeriche nel CSV.")
        col = candidates[0]

    if col not in df.columns:
        raise ValueError(f"Colonna '{col}' non trovata. Disponibili: {list(df.columns)}")

    y = pd.to_numeric(df[col], errors="coerce").dropna().astype(np.float32).to_numpy()
    if len(y) < 10:
        raise ValueError("Serie troppo corta.")
    return y, col
# -------------------------
# Model cache
# -------------------------
# Process-wide cache: the single loaded pipeline and the (model_id, device)
# pair it was created for.
_PIPE = None
_META = {"model_id": None, "device": None}


def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
    """Return a cached ``Chronos2Pipeline``, (re)loading it only when needed.

    Args:
        model_id: Identifier passed to ``Chronos2Pipeline.from_pretrained``.
            Empty, ``None`` or whitespace-only values fall back to
            ``MODEL_ID_DEFAULT``.
        device: ``"cuda"`` to request the GPU; anything else, or an
            unavailable CUDA runtime, resolves to ``"cpu"``.

    Returns:
        The pipeline for the resolved ``(model_id, device)``. The pipeline is
        reloaded whenever that pair differs from the cached one.
    """
    global _PIPE, _META
    # Strip before applying the fallback so that a whitespace-only value does
    # not end up as an empty model id.
    model_id = (model_id or "").strip() or MODEL_ID_DEFAULT.strip()
    device = "cuda" if device == "cuda" and torch.cuda.is_available() else "cpu"

    wanted = {"model_id": model_id, "device": device}
    if _PIPE is None or _META != wanted:
        _PIPE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
        _META = wanted
    return _PIPE
# -------------------------
# Predict (STABLE)
# -------------------------
def _to_numpy(x: Any) -> np.ndarray:
    """Convert ``x`` to a NumPy array.

    Handles NumPy arrays (returned as-is), torch tensors on any device, and
    lists/tuples that contain tensors (stacked along a new leading axis).
    Anything else goes through ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    if torch.is_tensor(x):
        x = x.detach().cpu()
        if x.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype; upcast before converting.
            x = x.to(torch.float32)
        return x.numpy()
    if isinstance(x, (list, tuple)) and any(torch.is_tensor(item) for item in x):
        # Convert each element explicitly: np.asarray on a list of tensors
        # fails for tensors that are not on the CPU.
        return np.stack([_to_numpy(item) for item in x])
    return np.asarray(x)


def _extract_samples(raw: Any) -> np.ndarray:
    """Pull the forecast array out of whatever ``predict`` returned.

    For a dict, the first present key among ``samples``, ``predictions``,
    ``prediction`` and ``output`` is used, otherwise the first value; an empty
    dict yields an empty float32 array. Any other object is converted directly.
    """
    if isinstance(raw, dict):
        for key in ("samples", "predictions", "prediction", "output"):
            if key in raw:
                return _to_numpy(raw[key])
        if len(raw) > 0:
            return _to_numpy(next(iter(raw.values())))
        return np.asarray([], dtype=np.float32)
    return _to_numpy(raw)


def _as_sample_matrix(arr: np.ndarray, horizon: int) -> np.ndarray:
    """Reduce a raw forecast array to shape ``(S, horizon)``.

    The last axis is treated as time. Extra leading axes (batch, variate, ...)
    are removed by taking their first entry. The time axis is then truncated,
    or padded by repeating the last step, to exactly ``horizon`` values.
    """
    if arr.size == 0:
        raise ValueError("Il modello non ha restituito previsioni.")
    if arr.ndim < 2:
        # A scalar or a flat vector is interpreted as one path over time.
        arr = arr.reshape(1, -1)
    # Index instead of np.squeeze: squeeze would also remove the time axis
    # when horizon == 1 and the sample axis would then be mistaken for time.
    while arr.ndim > 2:
        arr = arr[0]

    got = arr.shape[-1]
    if got > horizon:
        arr = arr[:, :horizon]
    elif got < horizon:
        tail = np.repeat(arr[:, -1:], horizon - got, axis=-1)
        arr = np.concatenate([arr, tail], axis=-1)
    return arr


def chronos2_predict(pipe: Chronos2Pipeline, y: np.ndarray, horizon: int, requested_samples: int) -> Tuple[np.ndarray, bool, str]:
    """Run ``pipe.predict`` on one series and return its forecast as ``(S, H)``.

    The signature of ``pipe.predict`` is inspected so that the horizon and,
    when supported, the number of samples are passed under whichever keyword
    the installed version accepts.

    Args:
        pipe: Object exposing a ``predict`` method.
        y: History values (1-D array-like).
        horizon: Number of future steps ``H``.
        requested_samples: Sample count, forwarded only if ``predict`` has a
            matching keyword.

    Returns:
        ``(samples, multi, note)`` where ``samples`` is a float32 array of
        shape ``(S, H)``, ``multi`` tells whether ``S > 1``, and ``note`` is a
        debug string describing the call.

    Raises:
        ValueError: ``predict`` returned no values.

    NOTE(review): Chronos-2 documents its ``predict`` output as one tensor per
    series shaped (n_variates, n_quantiles, prediction_length); in that case
    the ``S`` rows are quantile levels rather than sample paths. Confirm
    against the installed version.
    """
    horizon = int(horizon)
    sig = inspect.signature(pipe.predict)
    params = sig.parameters

    # Always a batch of one series, sent as a 1-D float32 tensor (a list of
    # tensors is an input format the pipelines document).
    inputs = [torch.as_tensor(np.asarray(y, dtype=np.float32))]

    horizon_kw = next(
        (c for c in ("prediction_length", "horizon", "steps", "n_steps", "pred_len") if c in params),
        None,
    )
    # Many versions have no keyword for the number of samples.
    sample_kw = next(
        (c for c in ("n_samples", "num_return_sequences", "num_samples") if c in params),
        None,
    )

    # With no recognised horizon keyword, fall back to ``prediction_length``.
    kwargs: Dict[str, Any] = {horizon_kw or "prediction_length": horizon}
    if sample_kw:
        kwargs[sample_kw] = int(requested_samples)

    if "inputs" in params:
        raw = pipe.predict(inputs=inputs, **kwargs)
    else:
        raw = pipe.predict(inputs, **kwargs)

    arr = _extract_samples(raw).astype(np.float32, copy=False)
    arr = _as_sample_matrix(arr, horizon)

    # With a single row a median can still be drawn, but a band is meaningless.
    real_multi = arr.shape[0] > 1
    note = f"predict_signature={sig} | used_horizon_kw={horizon_kw} | used_sample_kw={sample_kw} | got_shape={tuple(arr.shape)}"
    return arr, real_multi, note
# -------------------------
# Plotly
# -------------------------
def plot_forecast(y, median, low, high, title, show_band: bool, band_label: str) -> go.Figure:
    """Build the Plotly figure showing the history, the median forecast and an optional band.

    Args:
        y: History values, plotted at x = 0 .. len(y) - 1.
        median: Median forecast, plotted right after the history.
        low: Lower band values (used only when ``show_band`` is true).
        high: Upper band values (used only when ``show_band`` is true).
        title: Figure title.
        show_band: Whether to draw the filled band between ``low`` and ``high``.
        band_label: Legend label of the band.

    Returns:
        The assembled ``go.Figure``.
    """
    n_hist = len(y)
    x_hist = np.arange(n_hist)
    x_fcst = np.arange(n_hist, n_hist + len(median))

    fig = go.Figure()
    history_trace = go.Scatter(x=x_hist, y=y, mode="lines", name="History")
    median_trace = go.Scatter(x=x_fcst, y=median, mode="lines", name="Forecast (median)")
    for trace in (history_trace, median_trace):
        fig.add_trace(trace)

    # Dashed marker on the last observed point.
    fig.add_vline(x=n_hist - 1, line_width=1, line_dash="dash", opacity=0.6)

    if show_band:
        # The invisible upper line must come first: the lower trace fills up
        # to the previously added trace ("tonexty").
        upper_trace = go.Scatter(
            x=x_fcst, y=high, mode="lines", line=dict(width=0),
            showlegend=False, hoverinfo="skip",
        )
        lower_trace = go.Scatter(
            x=x_fcst, y=low, mode="lines", fill="tonexty",
            line=dict(width=0), name=band_label,
        )
        fig.add_trace(upper_trace)
        fig.add_trace(lower_trace)

    layout = dict(
        title=title,
        hovermode="x unified",
        margin=dict(l=10, r=10, t=55, b=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
        xaxis_title="t",
        yaxis_title="value",
    )
    fig.update_layout(**layout)
    return fig
def kpi_card(label: str, value: str, hint: str = "") -> str:
    """Render one KPI card as an HTML snippet.

    ``label``, ``value`` and ``hint`` are HTML-escaped before being inserted,
    because some of them originate from free-text user input (e.g. the model
    id) and must not be interpreted as markup.

    Args:
        label: Small caption shown at the top of the card.
        value: Main value shown in bold.
        hint: Optional secondary line; omitted from the markup when empty.

    Returns:
        The HTML string of the card.
    """
    from html import escape  # local import keeps the block self-contained

    safe_label = escape(str(label))
    safe_value = escape(str(value))
    hint_html = (
        f"<div style='opacity:.75;font-size:12px;margin-top:6px;'>{escape(str(hint))}</div>"
        if hint
        else ""
    )
    return (
        "\n"
        '<div style="border:1px solid rgba(255,255,255,.12); border-radius:16px; padding:14px 16px;\n'
        'background: rgba(255,255,255,.04);">\n'
        f'<div style="font-size:12px;opacity:.8;">{safe_label}</div>\n'
        f'<div style="font-size:22px;font-weight:700;margin-top:4px;">{safe_value}</div>\n'
        f"{hint_html}\n"
        "</div>\n"
    )
def kpi_grid(cards: List[str]) -> str:
    """Wrap the given card snippets in a six-column CSS grid container."""
    grid_style = "display:grid; grid-template-columns: repeat(6, minmax(0, 1fr)); gap:12px;"
    body = "".join(cards)
    return "<div style='" + grid_style + "'>" + body + "</div>"
def explain(y, median, low, high, band_enabled: bool, q_low: float, q_high: float, extra: str) -> str:
    """Build the Markdown explanation of a forecast.

    The trend wording compares the change of the median forecast over the
    horizon (last minus first step) with the magnitude of the historical mean:
    below 2% in absolute value it is reported as stable, otherwise as rising
    or falling.

    Args:
        y: History values.
        median: Median forecast (non-empty).
        low: Lower band values; read only when ``band_enabled`` is true.
        high: Upper band values; read only when ``band_enabled`` is true.
        band_enabled: Whether an uncertainty band is available.
        q_low: Lower quantile level of the band (0-1).
        q_high: Upper quantile level of the band (0-1).
        extra: Debug text placed in a collapsible section.

    Returns:
        The Markdown text.
    """
    horizon = len(median)
    base = float(np.mean(y))
    delta = float(median[-1] - median[0])
    # Use the magnitude of the mean: with a negative mean, max(1e-6, base)
    # would collapse to 1e-6 and blow the percentage up.
    pct = (delta / max(1e-6, abs(base))) * 100.0

    if abs(pct) < 2:
        trend_txt = "sostanzialmente stabile"
    elif pct > 0:
        trend_txt = "in crescita"
    else:
        trend_txt = "in calo"

    lines = [
        "",
        "### 🧠 Spiegazione",
        f"Nei prossimi **{horizon} step** la previsione mediana è **{trend_txt}** (variazione ≈ **{pct:+.1f}%** rispetto al livello medio storico).",
        f"- **Ultimo valore mediano previsto:** **{median[-1]:.2f}**",
    ]
    if band_enabled:
        # Bounds are separated explicitly; without a separator the two numbers
        # run together (e.g. "10%90%").
        lines.append(
            f"- **Banda [{q_low:.0%}, {q_high:.0%}] (ultimo step):** **[{low[-1]:.2f}, {high[-1]:.2f}]**"
        )
    else:
        lines.append(
            "- **Banda di incertezza:** disattivata (questa versione di Chronos2 non restituisce campioni multipli con i parametri disponibili)."
        )
    lines += ["", "<details><summary>Debug</summary>", "", f"`{extra}`", "", "</details>", ""]
    return "\n".join(lines)
# -------------------------
# Run
# -------------------------
def run_all(
    input_mode, test_csv_name, upload_csv, csv_column,
    n, seed, trend, season_period, season_amp, noise,
    prediction_length, requested_samples, q_low, q_high,
    device_ui, model_id,
):
    """Run one full forecast for the UI and build every output widget value.

    Steps: validate the quantiles, obtain the input series (bundled CSV,
    uploaded CSV or synthetic sample), load the pipeline, predict, derive the
    median and optional band, then assemble KPIs, plot, table, CSV export,
    explanation and debug info.

    Args:
        input_mode: ``"Test CSV"``, ``"Upload CSV"`` or anything else for the
            synthetic sample.
        test_csv_name: File name inside ``DATA_DIR`` (Test CSV mode).
        upload_csv: Uploaded file handle or path (Upload CSV mode).
        csv_column: Optional column name for the CSV modes.
        n, seed, trend, season_period, season_amp, noise: Parameters of the
            synthetic series.
        prediction_length: Forecast horizon.
        requested_samples: Number of samples requested from the model.
        q_low, q_high: Quantile levels of the band; ``q_low`` must be lower.
        device_ui: Device label selected in the UI.
        model_id: Model identifier.

    Returns:
        ``(kpis_html, explanation_md, figure, forecast_dataframe, csv_path, info_dict)``.

    Raises:
        gr.Error: Invalid quantiles, missing CSV selection/upload, or a CSV
            that cannot be read or contains no usable series.
    """
    if q_low >= q_high:
        raise gr.Error("Quantile low deve essere < quantile high.")

    # Resolve the data source first, so that bad input fails before the
    # (potentially slow) model load.
    csv_path = None
    source_prefix = ""
    if input_mode == "Test CSV":
        if not test_csv_name:
            raise gr.Error("Seleziona un Test CSV.")
        csv_path = os.path.join(DATA_DIR, test_csv_name)
        source_prefix = f"Test CSV: {test_csv_name}"
    elif input_mode == "Upload CSV":
        if upload_csv is None:
            raise gr.Error("Carica un CSV.")
        # Accept both an object exposing ``.name`` and a plain path string.
        csv_path = getattr(upload_csv, "name", upload_csv)
        source_prefix = "Upload CSV"

    if csv_path is None:
        y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
        source = "Sample series"
    else:
        try:
            y, used_col = load_series_from_csv(csv_path, csv_column)
        except (ValueError, OSError) as exc:
            # Re-raise as gr.Error so the message is shown to the user.
            raise gr.Error(str(exc)) from exc
        source = f"{source_prefix} • col={used_col}"

    device = pick_device(device_ui)
    pipe = get_pipeline(model_id, device)

    # Monotonic timer: unaffected by system clock adjustments.
    t0 = time.perf_counter()
    samples, real_multi, note = chronos2_predict(pipe, y, int(prediction_length), int(requested_samples))
    latency = time.perf_counter() - t0

    median = np.quantile(samples, 0.50, axis=0)
    band_enabled = real_multi and samples.shape[0] > 2
    if band_enabled:
        low = np.quantile(samples, float(q_low), axis=0)
        high = np.quantile(samples, float(q_high), axis=0)
    else:
        # No band: keep the arrays defined so downstream code stays uniform.
        low = median.copy()
        high = median.copy()

    # KPI cards
    cards = [
        kpi_card("Device", device.upper(), f"cuda_available={torch.cuda.is_available()}"),
        kpi_card("Latency", f"{latency:.2f}s", "predict()"),
        kpi_card("Samples", str(samples.shape[0]), "returned by model"),
        kpi_card("Band", "ON" if band_enabled else "OFF", "needs multi-samples"),
        kpi_card("Horizon", str(prediction_length)),
        kpi_card("Model", (model_id or MODEL_ID_DEFAULT)),
    ]
    kpis_html = kpi_grid(cards)

    # Plot
    fig = plot_forecast(
        y=y,
        median=median,
        low=low,
        high=high,
        title=f"Forecast — {source}",
        show_band=band_enabled,
        band_label=f"Band [{q_low:.2f}, {q_high:.2f}]",
    )

    # Table + export
    t_fcst = np.arange(len(y), len(y) + int(prediction_length))
    out_df = pd.DataFrame({
        "t": t_fcst,
        "median": median,
    })
    if band_enabled:
        out_df[f"q{q_low:.2f}"] = low
        out_df[f"q{q_high:.2f}"] = high
    out_path = os.path.join(OUT_DIR, "chronos2_forecast.csv")
    out_df.to_csv(out_path, index=False)

    explanation_md = explain(y, median, low, high, band_enabled, q_low, q_high, note)
    info = {
        "source": source,
        "history_points": int(len(y)),
        "prediction_length": int(prediction_length),
        "requested_samples": int(requested_samples),
        "returned_samples": int(samples.shape[0]),
        "band_enabled": bool(band_enabled),
        "predict_signature": str(inspect.signature(pipe.predict)),
        "debug_note": note,
    }
    return kpis_html, explanation_md, fig, out_df, out_path, info
# -------------------------
# UI
# -------------------------
# Caps the width of the Gradio container.
css = """.gradio-container { max-width: 1200px !important; }"""
# NOTE(review): the nesting of the layout below was reconstructed from a
# source whose indentation had been stripped — confirm against the running app.
with gr.Blocks(title="Chronos-2 • Pro Dashboard (Stable)", css=css) as demo:
    gr.Markdown("# ⏱️ Chronos-2 Forecast Dashboard — Stable Edition")
    with gr.Row():
        # Left column: data source, model/device settings and forecast options.
        with gr.Column(scale=1, min_width=360):
            input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Input")
            # Choices are read once, when the UI is built.
            test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV (data/)")
            upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
            csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
            device_ui = gr.Dropdown(
                ["cpu", "cuda (se disponibile)"],
                value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
                label="Device",
            )
            model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")
            # Parameters of the synthetic series (forwarded to run_all for the "Sample" input).
            with gr.Accordion("Sample generator", open=False):
                n = gr.Slider(60, 2000, value=300, step=10, label="History length")
                seed = gr.Number(value=42, precision=0, label="Seed")
                trend = gr.Slider(0.0, 0.2, value=0.03, step=0.005, label="Trend")
                season_period = gr.Slider(2, 240, value=14, step=1, label="Season period")
                season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
                noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")
            prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
            # NOTE(review): with minimum 1 and step 25 the default 200 is not on
            # the slider grid (1, 26, 51, ...) — confirm this is intended.
            requested_samples = gr.Slider(1, 800, value=200, step=25, label="Requested samples (best effort)")
            # The two ranges do not overlap, so q_low < q_high for any slider position.
            q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
            q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")
            run_btn = gr.Button("Run", variant="primary")
        # Right column: KPI strip and result tabs.
        with gr.Column(scale=2):
            kpis = gr.HTML()
            with gr.Tabs():
                with gr.Tab("Forecast"):
                    forecast_plot = gr.Plot()
                    forecast_table = gr.Dataframe(interactive=False)
                with gr.Tab("Spiegazione"):
                    explanation = gr.Markdown()
                with gr.Tab("Export"):
                    download = gr.File()
                with gr.Tab("Info"):
                    info = gr.JSON()
    # The order of inputs/outputs must match run_all's parameters and return tuple.
    run_btn.click(
        fn=run_all,
        inputs=[
            input_mode, test_csv_name, upload_csv, csv_column,
            n, seed, trend, season_period, season_amp, noise,
            prediction_length, requested_samples, q_low, q_high,
            device_ui, model_id,
        ],
        outputs=[kpis, explanation, forecast_plot, forecast_table, download, info],
    )
# Enable request queuing, then start the server when the module is executed.
demo.queue()
demo.launch(ssr_mode=False)