bosh94 committed on
Commit 9c986ec · verified · 1 Parent(s): d509c45

Update app.py

Files changed (1): app.py (+510 -223)

app.py CHANGED
@@ -1,9 +1,14 @@
  import os
  import numpy as np
  import pandas as pd
  import gradio as gr
- import matplotlib.pyplot as plt
  import torch

  from chronos import Chronos2Pipeline

@@ -13,322 +18,604 @@ from chronos import Chronos2Pipeline
  # =========================
  MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
  DATA_DIR = "data"


  # =========================
- # Utils
  # =========================
- def available_test_csv():
      if not os.path.isdir(DATA_DIR):
          return []
-     return sorted(f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv"))


  def pick_device(ui_choice: str) -> str:
-     if (ui_choice or "").startswith("cuda") and torch.cuda.is_available():
          return "cuda"
      return "cpu"


- def make_sample_series(n, seed, trend, season_period, season_amp, noise):
      rng = np.random.default_rng(int(seed))
-     t = np.arange(int(n))
      y = (
          float(trend) * t
          + float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period)))
-         + rng.normal(0.0, float(noise), size=len(t))
-     )
-     # shift up if negative to keep plots nice
-     mn = float(np.min(y))
-     if mn < 0:
-         y = y - mn
-     return y.astype(np.float32)


- def load_series_from_csv(path_or_file, column=None):
49
- df = pd.read_csv(path_or_file)
50
  if df.shape[1] == 0:
51
  raise ValueError("CSV vuoto o non leggibile.")
52
 
53
  col = (column or "").strip()
54
- if col == "":
55
- # try native numeric dtypes first
56
  numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
57
- # fallback: try coercion
58
  if not numeric_cols:
 
59
  for c in df.columns:
60
  coerced = pd.to_numeric(df[c], errors="coerce")
61
- if coerced.notna().sum() >= 10:
62
  numeric_cols.append(c)
63
- if not numeric_cols:
64
- raise ValueError("Nessuna colonna numerica nel CSV. Specifica la colonna corretta.")
65
  col = numeric_cols[0]
66
 
67
  if col not in df.columns:
68
- raise ValueError(f"Colonna '{col}' non trovata. Colonne: {list(df.columns)}")
69
 
70
- y = pd.to_numeric(df[col], errors="coerce").dropna().to_numpy()
71
  if len(y) < 10:
72
- raise ValueError("Serie troppo corta (minimo ~10 punti dopo dropna).")
73
 
74
- return y.astype(np.float32), col
75
 
76
 
  # =========================
- # Pipeline cache
  # =========================
  _PIPELINE = None
- _PIPELINE_META = {}


- def get_pipeline(model_id: str, device: str):
      global _PIPELINE, _PIPELINE_META

      model_id = (model_id or MODEL_ID_DEFAULT).strip()
-     device = "cuda" if device == "cuda" and torch.cuda.is_available() else "cpu"
-
-     if (
-         _PIPELINE is None
-         or _PIPELINE_META.get("model_id") != model_id
-         or _PIPELINE_META.get("device") != device
-     ):
-         _PIPELINE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
          _PIPELINE_META = {"model_id": model_id, "device": device}

      return _PIPELINE


  # =========================
- # Chronos-2 predict_df helpers
  # =========================
- def build_context_df(y: np.ndarray, freq: str = "D"):
-     """
-     Build a minimal context DataFrame compatible with Chronos2Pipeline.predict_df().
-     We generate a synthetic timestamp index so it works for Sample and numeric-only CSV.
-     """
-     ts = pd.date_range("2000-01-01", periods=len(y), freq=freq)
-     return pd.DataFrame({"id": "series_0", "timestamp": ts, "target": y})
-
-
- def pick_quantile_column(pred_df: pd.DataFrame, q: float) -> str:
-     """
-     Column naming can vary. We robustly find a column representing quantile q.
-     Common patterns: "0.1", "0.5", "0.9" OR "q0.1" OR "quantile_0.1" etc.
-     """
-     q = float(q)
-     # direct numeric-string match
-     for c in pred_df.columns:
-         try:
-             if abs(float(c) - q) < 1e-9:
-                 return c
-         except Exception:
-             pass
-
-     # prefixed patterns
-     candidates = []
-     for c in pred_df.columns:
-         lc = str(c).lower()
-         if "quant" in lc or lc.startswith("q"):
-             # try to extract float from tail
-             for token in [lc.replace("quantile", "").replace("_", ""), lc.replace("q", "")]:
-                 try:
-                     if abs(float(token) - q) < 1e-9:
-                         candidates.append(c)
-                 except Exception:
-                     pass
-
-     if candidates:
-         return candidates[0]
-
-     raise ValueError(
-         f"Non riesco a trovare la colonna del quantile {q}. "
-         f"Colonne disponibili: {list(pred_df.columns)}"
      )


  # =========================
- # Forecast core
  # =========================
- def run_forecast(
-     input_mode,
-     test_csv_name,
      upload_csv,
-     csv_column,
-     n,
-     seed,
-     trend,
-     season_period,
-     season_amp,
-     noise,
-     prediction_length,
-     q_low,
-     q_high,
-     device_ui,
-     model_id,
- ):
-     q_low = float(q_low)
-     q_high = float(q_high)
      if q_low >= q_high:
          raise gr.Error("Quantile low deve essere < quantile high.")

  device = pick_device(device_ui)
175
- pipe = get_pipeline(model_id, device)
176
 
177
- # 1) pick data
 
 
178
  if input_mode == "Test CSV":
179
  if not test_csv_name:
180
- raise gr.Error("Seleziona un file nella dropdown dei Test CSV.")
181
- path = os.path.join(DATA_DIR, test_csv_name)
182
- if not os.path.exists(path):
183
- raise gr.Error(f"Non trovo {path}. Assicurati che sia nel repo.")
184
- y, used_col = load_series_from_csv(path, csv_column)
185
- source = f"Test CSV: {test_csv_name} ({used_col})"
186
 
187
  elif input_mode == "Upload CSV":
188
  if upload_csv is None:
189
- raise gr.Error("Carica un CSV per usare la modalità Upload.")
190
- y, used_col = load_series_from_csv(upload_csv.name, csv_column)
191
- source = f"Upload CSV ({used_col})"
192
-
193
- else: # Sample
194
- y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
195
- source = "Sample data"
196
-
197
- # 2) build context df (single series)
198
- context_df = build_context_df(y, freq="D")
 
 
 
 
 
 
 
 
199
 
-     # 3) predict quantiles via predict_df (stable API per chronos-2)
-     quantiles = sorted({q_low, 0.5, q_high})
-     pred_df = pipe.predict_df(
-         context_df,
          prediction_length=int(prediction_length),
-         quantile_levels=quantiles,
-         id_column="id",
-         timestamp_column="timestamp",
-         target="target",
      )

-     # 4) extract arrays
-     col_low = pick_quantile_column(pred_df, q_low)
-     col_med = pick_quantile_column(pred_df, 0.5)
-     col_high = pick_quantile_column(pred_df, q_high)
-
-     # pred_df contains the forecast horizon rows; keep only series_0
-     pred_df = pred_df[pred_df["id"] == "series_0"].copy()
-
-     ts_fcst = pd.to_datetime(pred_df["timestamp"]).to_numpy()
-     low = pred_df[col_low].to_numpy(dtype=np.float32)
-     median = pred_df[col_med].to_numpy(dtype=np.float32)
-     high = pred_df[col_high].to_numpy(dtype=np.float32)

-     # 5) plot (use integer axis for simplicity)
-     t_hist = np.arange(len(y))
      t_fcst = np.arange(len(y), len(y) + int(prediction_length))
-
-     fig, ax = plt.subplots(figsize=(10, 4))
-     ax.plot(t_hist, y, label="history")
-     ax.plot(t_fcst, median, label="forecast (median)")
-     ax.fill_between(t_fcst, low, high, alpha=0.25, label=f"band [{q_low:.2f}, {q_high:.2f}]")
-     ax.axvline(len(y) - 1, linestyle="--", linewidth=1)
-     ax.set_title(source)
-     ax.set_xlabel("t")
-     ax.set_ylabel("value")
-     ax.grid(True, alpha=0.3)
-     ax.legend()
-
-     # 6) output table + downloadable csv
-     out_df = pd.DataFrame(
-         {
-             "t": t_fcst,
-             "timestamp": ts_fcst,
-             "median": median,
-             f"q{q_low:.2f}": low,
-             f"q{q_high:.2f}": high,
-         }
      )

-     out_path = "/tmp/chronos2_forecast.csv"
-     out_df.to_csv(out_path, index=False)

      info = {
-         "model_id": (model_id or MODEL_ID_DEFAULT),
-         "device": device,
-         "source": source,
          "history_points": int(len(y)),
          "prediction_length": int(prediction_length),
-         "quantile_levels": quantiles,
-         "pred_df_columns": list(out_df.columns),
      }

-     return fig, out_df, out_path, info


  # =========================
  # UI
  # =========================
- with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
      gr.Markdown(
-         "# ⏱️ Chronos-2 Forecast Demo (HF Spaces)\n"
-         "- **Sample**: genera una serie sintetica\n"
-         "- **Test CSV**: usa file in `data/`\n"
-         "- **Upload CSV**: carica un tuo CSV\n\n"
-         "Questa versione usa **predict_df()** (API consigliata per Chronos-2) e calcola direttamente i **quantili**. "
      )

      with gr.Row():
-         input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Input source")
-         device_ui = gr.Dropdown(
-             ["cpu", "cuda (se disponibile)"],
-             value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
-             label="Device",
-         )
-         model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")
-
-     with gr.Row():
-         test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV disponibili (data/)")
-         upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
-         csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
-
-     with gr.Accordion("Sample data settings", open=False):
-         n = gr.Slider(60, 600, 220, step=10, label="History length")
-         seed = gr.Number(42, precision=0, label="Seed")
-         trend = gr.Slider(0.0, 0.2, 0.03, step=0.005, label="Trend")
-         season_period = gr.Slider(2, 90, 14, step=1, label="Season period")
-         season_amp = gr.Slider(0.0, 10.0, 3.0, step=0.1, label="Season amplitude")
-         noise = gr.Slider(0.0, 5.0, 0.8, step=0.05, label="Noise")
-
-     with gr.Accordion("Forecast settings", open=True):
-         prediction_length = gr.Slider(1, 180, 30, step=1, label="Prediction length")
-         q_low = gr.Slider(0.01, 0.49, 0.10, step=0.01, label="Quantile low")
-         q_high = gr.Slider(0.51, 0.99, 0.90, step=0.01, label="Quantile high")
-
-     run_btn = gr.Button("Run forecast", variant="primary")
-
-     plot = gr.Plot(label="Forecast")
-     table = gr.Dataframe(label="Forecast values", interactive=False)
-     download = gr.File(label="Download CSV")
-     info = gr.JSON(label="Run info")

      run_btn.click(
-         fn=run_forecast,
          inputs=[
-             input_mode,
-             test_csv_name,
-             upload_csv,
-             csv_column,
-             n,
-             seed,
-             trend,
-             season_period,
-             season_amp,
-             noise,
-             prediction_length,
-             q_low,
-             q_high,
-             device_ui,
-             model_id,
          ],
-         outputs=[plot, table, download, info],
      )

  demo.queue()
 
  import os
+ import math
+ import time
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
  import numpy as np
  import pandas as pd
  import gradio as gr
  import torch
+ import plotly.graph_objects as go

  from chronos import Chronos2Pipeline

  # =========================
  MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
  DATA_DIR = "data"
+ OUT_DIR = "/tmp"


  # =========================
+ # Utilities
  # =========================
+ def available_test_csv() -> list[str]:
      if not os.path.isdir(DATA_DIR):
          return []
+     return sorted([f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv")])


  def pick_device(ui_choice: str) -> str:
+     if ui_choice and ui_choice.startswith("cuda") and torch.cuda.is_available():
          return "cuda"
      return "cpu"


+ def safe_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+     denom = np.maximum(1e-8, np.abs(y_true))
+     return float(np.mean(np.abs((y_true - y_pred) / denom)) * 100.0)
+
+
+ def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+     return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))
+
+
+ def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+     return float(np.mean(np.abs(y_true - y_pred)))
+
+
+ def coverage(y_true: np.ndarray, low: np.ndarray, high: np.ndarray) -> float:
+     inside = (y_true >= low) & (y_true <= high)
+     return float(np.mean(inside) * 100.0)
+
+
+ def interval_width(low: np.ndarray, high: np.ndarray) -> float:
+     return float(np.mean(high - low))
+
+
+ def format_kpi(label: str, value: str, hint: str = "") -> str:
+     # Simple “card” layout via HTML
+     hint_html = f"<div style='opacity:.75;font-size:12px;margin-top:6px;'>{hint}</div>" if hint else ""
+     return f"""
+     <div style="
+         border:1px solid rgba(255,255,255,.12);
+         border-radius:16px;
+         padding:14px 16px;
+         background: rgba(255,255,255,.04);
+         backdrop-filter: blur(6px);
+     ">
+         <div style="font-size:12px;opacity:.8;">{label}</div>
+         <div style="font-size:22px;font-weight:700;margin-top:4px;">{value}</div>
+         {hint_html}
+     </div>
+     """
+
+
+ def make_sample_series(
+     n: int,
+     seed: int,
+     trend: float,
+     season_period: int,
+     season_amp: float,
+     noise: float,
+     positive_shift: bool = True,
+ ) -> np.ndarray:
      rng = np.random.default_rng(int(seed))
+     t = np.arange(int(n), dtype=np.float32)
      y = (
          float(trend) * t
          + float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period)))
+         + rng.normal(0.0, float(noise), size=int(n))
+     ).astype(np.float32)
+
+     if positive_shift and float(np.min(y)) < 0:
+         y = y - float(np.min(y))
+     return y


+ def load_series_from_csv(csv_path: str, column: Optional[str]) -> Tuple[np.ndarray, str, pd.DataFrame]:
+     df = pd.read_csv(csv_path)
      if df.shape[1] == 0:
          raise ValueError("CSV vuoto o non leggibile.")

      col = (column or "").strip()
+     if not col:
          numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
          if not numeric_cols:
+             # Try coercion: maybe numeric stored as strings
              for c in df.columns:
                  coerced = pd.to_numeric(df[c], errors="coerce")
+                 if coerced.notna().sum() > 0:
                      numeric_cols.append(c)
+         if not numeric_cols:
+             raise ValueError("Non trovo colonne numeriche nel CSV.")
          col = numeric_cols[0]

      if col not in df.columns:
+         raise ValueError(f"Colonna '{col}' non trovata. Disponibili: {list(df.columns)}")

+     y = pd.to_numeric(df[col], errors="coerce").dropna().astype(np.float32).to_numpy()
      if len(y) < 10:
+         raise ValueError("Serie troppo corta (minimo consigliato: 10 punti).")

+     return y, col, df


  # =========================
+ # Model cache
  # =========================
  _PIPELINE = None
+ _PIPELINE_META = {"model_id": None, "device": None}


+ def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
      global _PIPELINE, _PIPELINE_META

      model_id = (model_id or MODEL_ID_DEFAULT).strip()
+     device = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"
+
+     if _PIPELINE is None or _PIPELINE_META["model_id"] != model_id or _PIPELINE_META["device"] != device:
+         pipe = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
+         _PIPELINE = pipe
          _PIPELINE_META = {"model_id": model_id, "device": device}

      return _PIPELINE


  # =========================
+ # Plotly helpers
  # =========================
+ def plot_forecast_interactive(
+     y: np.ndarray,
+     median: np.ndarray,
+     low: np.ndarray,
+     high: np.ndarray,
+     title: str,
+     q_low: float,
+     q_high: float,
+ ) -> go.Figure:
+     t_hist = np.arange(len(y))
+     t_fcst = np.arange(len(y), len(y) + len(median))
+
+     fig = go.Figure()
+
+     fig.add_trace(go.Scatter(
+         x=t_hist, y=y, mode="lines",
+         name="History",
+         hovertemplate="t=%{x}<br>y=%{y:.4f}<extra></extra>"
+     ))
+
+     # Upper bound (invisible line), then lower bound with fill to create band
+     fig.add_trace(go.Scatter(
+         x=t_fcst, y=high, mode="lines",
+         name="Upper",
+         line=dict(width=0),
+         showlegend=False,
+         hoverinfo="skip"
+     ))
+     fig.add_trace(go.Scatter(
+         x=t_fcst, y=low, mode="lines",
+         name=f"Band [{q_low:.2f}, {q_high:.2f}]",
+         fill="tonexty",
+         line=dict(width=0),
+         hovertemplate="t=%{x}<br>low=%{y:.4f}<extra></extra>"
+     ))
+
+     fig.add_trace(go.Scatter(
+         x=t_fcst, y=median, mode="lines",
+         name="Forecast (median)",
+         hovertemplate="t=%{x}<br>median=%{y:.4f}<extra></extra>"
+     ))
+
+     fig.add_vline(x=len(y) - 1, line_width=1, line_dash="dash", opacity=0.6)
+
+     fig.update_layout(
+         title=title,
+         hovermode="x unified",
+         margin=dict(l=10, r=10, t=55, b=10),
+         legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
+         xaxis_title="t",
+         yaxis_title="value",
      )
+     return fig
+
+
+ def plot_backtest_interactive(
+     y_train: np.ndarray,
+     y_true: np.ndarray,
+     pred_median: np.ndarray,
+     low: np.ndarray,
+     high: np.ndarray,
+     q_low: float,
+     q_high: float,
+ ) -> go.Figure:
+     t_train = np.arange(len(y_train))
+     t_test = np.arange(len(y_train), len(y_train) + len(y_true))
+
+     fig = go.Figure()
+     fig.add_trace(go.Scatter(x=t_train, y=y_train, mode="lines", name="Train"))
+     fig.add_trace(go.Scatter(x=t_test, y=y_true, mode="lines", name="True (holdout)"))
+
+     fig.add_trace(go.Scatter(x=t_test, y=high, mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip"))
+     fig.add_trace(go.Scatter(
+         x=t_test, y=low, mode="lines",
+         fill="tonexty", line=dict(width=0),
+         name=f"Band [{q_low:.2f}, {q_high:.2f}]"
+     ))
+     fig.add_trace(go.Scatter(x=t_test, y=pred_median, mode="lines", name="Pred (median)"))
+
+     fig.add_vline(x=len(y_train) - 1, line_width=1, line_dash="dash", opacity=0.6)
+
+     fig.update_layout(
+         title="Backtest (holdout) — interactive",
+         hovermode="x unified",
+         margin=dict(l=10, r=10, t=55, b=10),
+         legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
+         xaxis_title="t",
+         yaxis_title="value",
+     )
+     return fig
+
+
+ def plot_sample_distribution(samples: np.ndarray) -> go.Figure:
+     # show distribution for a few horizons
+     samples = np.asarray(samples)
+     if samples.ndim == 3:
+         # if a leading (num_series) batch dimension is present, keep the single series
+         samples = samples[0]
+     n_h = samples.shape[1]
+     idxs = []
+     for frac in [0.1, 0.5, 0.9]:
+         i = int(round((n_h - 1) * frac))
+         idxs.append(i)
+     idxs = sorted(set(idxs))
+
+     fig = go.Figure()
+     for i in idxs:
+         fig.add_trace(go.Histogram(
+             x=samples[:, i],
+             name=f"h={i+1}",
+             opacity=0.6
+         ))
+     fig.update_layout(
+         barmode="overlay",
+         title="Forecast sample distributions (selected horizons)",
+         margin=dict(l=10, r=10, t=55, b=10),
+         xaxis_title="value",
+         yaxis_title="count",
+     )
+     return fig


  # =========================
+ # Core run
  # =========================
+ @dataclass
+ class RunResult:
+     forecast_fig: go.Figure
+     backtest_fig: Optional[go.Figure]
+     dist_fig: go.Figure
+     forecast_table: pd.DataFrame
+     backtest_table: Optional[pd.DataFrame]
+     forecast_csv_path: str
+     backtest_csv_path: Optional[str]
+     kpi_html: str
+     info: dict
+
+
+ def run_dashboard(
+     input_mode: str,
+     test_csv_name: str,
      upload_csv,
+     csv_column: str,
+
+     # sample params
+     n: int,
+     seed: int,
+     trend: float,
+     season_period: int,
+     season_amp: float,
+     noise: float,
+
+     # forecast params
+     prediction_length: int,
+     num_samples: int,
+     q_low: float,
+     q_high: float,
+
+     # backtest
+     do_backtest: bool,
+     holdout: int,
+
+     # system
+     device_ui: str,
+     model_id: str,
+ ) -> RunResult:
+
      if q_low >= q_high:
          raise gr.Error("Quantile low deve essere < quantile high.")

      device = pick_device(device_ui)

+     # ---------
+     # Load series
+     # ---------
      if input_mode == "Test CSV":
          if not test_csv_name:
+             raise gr.Error("Seleziona un Test CSV dalla dropdown.")
+         csv_path = os.path.join(DATA_DIR, test_csv_name)
+         if not os.path.exists(csv_path):
+             raise gr.Error(f"File non trovato: {csv_path}")
+         y, used_col, df_preview = load_series_from_csv(csv_path, csv_column)
+         source_title = f"Test CSV: {test_csv_name} • col={used_col}"

      elif input_mode == "Upload CSV":
          if upload_csv is None:
+             raise gr.Error("Carica un CSV (Upload CSV) oppure cambia modalità.")
+         y, used_col, df_preview = load_series_from_csv(upload_csv.name, csv_column)
+         source_title = f"Upload CSV • col={used_col}"
+
+     else:
+         y = make_sample_series(n, seed, trend, season_period, season_amp, noise, positive_shift=True)
+         df_preview = pd.DataFrame({"value": y})
+         used_col = "value"
+         source_title = "Sample series"
+
+     if do_backtest and holdout >= len(y):
+         raise gr.Error("Holdout deve essere più piccolo della lunghezza dello storico.")
+
+     # ---------
+     # Model
+     # ---------
+     t0 = time.time()
+     pipe = get_pipeline(model_id, device)

+     # ---------
+     # Forecast
+     # ---------
+     samples = pipe.predict(
+         context=y.tolist(),
          prediction_length=int(prediction_length),
+         num_samples=int(num_samples),
      )
+     samples = np.asarray(samples, dtype=np.float32)
+     if samples.ndim == 3:
+         # predict() may return (num_series, num_samples, horizon); keep the single series
+         samples = samples[0]

+     median = np.quantile(samples, 0.50, axis=0)
+     low = np.quantile(samples, float(q_low), axis=0)
+     high = np.quantile(samples, float(q_high), axis=0)

+     # Tables
      t_fcst = np.arange(len(y), len(y) + int(prediction_length))
+     forecast_df = pd.DataFrame({
+         "t": t_fcst,
+         "median": median,
+         f"q{q_low:.2f}": low,
+         f"q{q_high:.2f}": high,
+     })
+
+     forecast_csv_path = os.path.join(OUT_DIR, "chronos2_forecast.csv")
+     forecast_df.to_csv(forecast_csv_path, index=False)
+
+     # Plots
+     forecast_fig = plot_forecast_interactive(
+         y=y,
+         median=median,
+         low=low,
+         high=high,
+         title=f"Forecast — {source_title}",
+         q_low=q_low,
+         q_high=q_high,
      )
+     dist_fig = plot_sample_distribution(samples)
+
+     # ---------
+     # Backtest (optional)
+     # ---------
+     backtest_fig = None
+     backtest_df = None
+     backtest_csv_path = None
+
+     kpi_items = []
+     # Always show run/system KPIs
+     elapsed = time.time() - t0
+     kpi_items.append(format_kpi("Device", device.upper(), f"torch.cuda={torch.cuda.is_available()}"))
+     kpi_items.append(format_kpi("Model", (model_id or MODEL_ID_DEFAULT), "Chronos-2 pipeline"))
+     kpi_items.append(format_kpi("Latency", f"{elapsed:.2f}s", "model load cached after first run"))
+     kpi_items.append(format_kpi("Samples", f"{int(num_samples)}", "more = smoother quantiles"))
+
+     # Coverage/width (forecast only) – informational
+     kpi_items.append(format_kpi("Interval", f"[{q_low:.2f}, {q_high:.2f}]", "uncertainty band"))
+     kpi_items.append(format_kpi("Avg band width", f"{interval_width(low, high):.3f}", "forecast band only"))
+
+     if do_backtest:
+         y_train = y[:-int(holdout)]
+         y_true = y[-int(holdout):]
+
+         bt_samples = pipe.predict(
+             context=y_train.tolist(),
+             prediction_length=int(holdout),
+             num_samples=int(num_samples),
+         )
+         bt_samples = np.asarray(bt_samples, dtype=np.float32)
+         if bt_samples.ndim == 3:
+             # same shape handling as above: keep the single series
+             bt_samples = bt_samples[0]
+         bt_median = np.quantile(bt_samples, 0.50, axis=0)
+         bt_low = np.quantile(bt_samples, float(q_low), axis=0)
+         bt_high = np.quantile(bt_samples, float(q_high), axis=0)
+
+         # Metrics
+         bt_mae = mae(y_true, bt_median)
+         bt_rmse = rmse(y_true, bt_median)
+         bt_mape = safe_mape(y_true, bt_median)
+         bt_cov = coverage(y_true, bt_low, bt_high)
+         bt_w = interval_width(bt_low, bt_high)
+
+         kpi_items.append(format_kpi("Backtest MAE", f"{bt_mae:.3f}", f"holdout={holdout}"))
+         kpi_items.append(format_kpi("Backtest RMSE", f"{bt_rmse:.3f}", ""))
+         kpi_items.append(format_kpi("Backtest MAPE", f"{bt_mape:.2f}%", ""))
+         kpi_items.append(format_kpi("Coverage", f"{bt_cov:.1f}%", "inside band"))
+         kpi_items.append(format_kpi("Backtest width", f"{bt_w:.3f}", "avg band width"))
+
+         backtest_fig = plot_backtest_interactive(
+             y_train=y_train,
+             y_true=y_true,
+             pred_median=bt_median,
+             low=bt_low,
+             high=bt_high,
+             q_low=q_low,
+             q_high=q_high,
+         )

+         t_test = np.arange(len(y_train), len(y_train) + int(holdout))
+         backtest_df = pd.DataFrame({
+             "t": t_test,
+             "true": y_true,
+             "pred_median": bt_median,
+             f"q{q_low:.2f}": bt_low,
+             f"q{q_high:.2f}": bt_high,
+         })
+         backtest_csv_path = os.path.join(OUT_DIR, "chronos2_backtest.csv")
+         backtest_df.to_csv(backtest_csv_path, index=False)
+
+     kpi_html = f"""
+     <div style="display:grid; grid-template-columns: repeat(6, minmax(0, 1fr)); gap:12px;">
+         {''.join(kpi_items)}
+     </div>
+     """

      info = {
+         "source": source_title,
          "history_points": int(len(y)),
          "prediction_length": int(prediction_length),
+         "num_samples": int(num_samples),
+         "q_low": float(q_low),
+         "q_high": float(q_high),
+         "backtest": bool(do_backtest),
+         "holdout": int(holdout) if do_backtest else None,
+         "column_used": used_col,
      }

+     return RunResult(
+         forecast_fig=forecast_fig,
+         backtest_fig=backtest_fig,
+         dist_fig=dist_fig,
+         forecast_table=forecast_df,
+         backtest_table=backtest_df,
+         forecast_csv_path=forecast_csv_path,
+         backtest_csv_path=backtest_csv_path,
+         kpi_html=kpi_html,
+         info=info,
+     )
+
+
+ def run_dashboard_wrapped(*args):
+     res = run_dashboard(*args)
+     # outputs must be basic objects
+     # If no backtest, send an empty placeholder plot and empty table/file
+     empty_fig = go.Figure().update_layout(
+         title="Backtest disabled",
+         margin=dict(l=10, r=10, t=55, b=10),
+     )
+     empty_df = pd.DataFrame()
+
+     return (
+         res.kpi_html,
+         res.forecast_fig,
+         (res.backtest_fig if res.backtest_fig is not None else empty_fig),
+         res.dist_fig,
+         res.forecast_table,
+         (res.backtest_table if res.backtest_table is not None else empty_df),
+         res.forecast_csv_path,
+         (res.backtest_csv_path if res.backtest_csv_path is not None else None),
+         res.info,
+     )


  # =========================
  # UI
  # =========================
+ css = """
+ :root { --radius: 18px; }
+ .gradio-container { max-width: 1200px !important; }
+ """
+
+ with gr.Blocks(title="Chronos-2 • Forecast Dashboard", css=css) as demo:
      gr.Markdown(
+         """
+         # ⏱️ Chronos-2 Forecast Dashboard
+         Una dashboard interattiva (Plotly) per testare **Amazon Chronos-2** su serie storiche (sample / CSV / upload), con **bande di incertezza** e **backtest**.
+         """
      )

  with gr.Row():
532
+ with gr.Column(scale=1, min_width=360):
533
+ gr.Markdown("## Input")
534
+
535
+ input_mode = gr.Radio(
536
+ ["Sample", "Test CSV", "Upload CSV"],
537
+ value="Sample",
538
+ label="Sorgente dati",
539
+ )
540
+
541
+ test_csv_name = gr.Dropdown(
542
+ choices=available_test_csv(),
543
+ value=None,
544
+ label="Test CSV (cartella data/)",
545
+ info="Comparirà qui se metti .csv dentro data/",
546
+ )
547
+
548
+ upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
549
+ csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
550
+
551
+ gr.Markdown("## Sistema")
552
+ device_ui = gr.Dropdown(
553
+ ["cpu", "cuda (se disponibile)"],
554
+ value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
555
+ label="Device",
556
+ )
557
+ model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")
558
+
559
+ with gr.Accordion("Sample generator", open=False):
560
+ n = gr.Slider(60, 1200, value=300, step=10, label="History length")
561
+ seed = gr.Number(value=42, precision=0, label="Seed")
562
+ trend = gr.Slider(0.0, 0.2, value=0.03, step=0.005, label="Trend")
563
+ season_period = gr.Slider(2, 120, value=14, step=1, label="Season period")
564
+ season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
565
+ noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")
566
+
567
+ gr.Markdown("## Forecast settings")
568
+ prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
569
+ num_samples = gr.Slider(50, 800, value=300, step=25, label="Num samples (quantili più stabili)")
570
+ q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
571
+ q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")
572
+
573
+ gr.Markdown("## Backtest")
574
+ do_backtest = gr.Checkbox(value=True, label="Esegui backtest holdout")
575
+ holdout = gr.Slider(5, 240, value=30, step=1, label="Holdout points")
576
+
577
+ run_btn = gr.Button("Run", variant="primary")
578
+
+         with gr.Column(scale=2):
+             gr.Markdown("## KPI")
+             kpi_html = gr.HTML()
+
+             with gr.Tabs():
+                 with gr.Tab("Forecast"):
+                     forecast_plot = gr.Plot(label="Interactive forecast (Plotly)")
+                     forecast_table = gr.Dataframe(label="Forecast table", interactive=False)
+
+                 with gr.Tab("Backtest"):
+                     backtest_plot = gr.Plot(label="Interactive backtest (Plotly)")
+                     backtest_table = gr.Dataframe(label="Backtest table", interactive=False)
+
+                 with gr.Tab("Distributions"):
+                     dist_plot = gr.Plot(label="Sample distributions (selected horizons)")
+
+                 with gr.Tab("Export"):
+                     gr.Markdown("Scarica i CSV prodotti dall’ultima run:")
+                     forecast_download = gr.File(label="Forecast CSV")
+                     backtest_download = gr.File(label="Backtest CSV")
+
+                 with gr.Tab("Run info"):
+                     run_info = gr.JSON(label="Info")

      run_btn.click(
+         fn=run_dashboard_wrapped,
          inputs=[
+             input_mode, test_csv_name, upload_csv, csv_column,
+             n, seed, trend, season_period, season_amp, noise,
+             prediction_length, num_samples, q_low, q_high,
+             do_backtest, holdout,
+             device_ui, model_id,
+         ],
+         outputs=[
+             kpi_html,
+             forecast_plot, backtest_plot, dist_plot,
+             forecast_table, backtest_table,
+             forecast_download, backtest_download,
+             run_info,
          ],
      )

  demo.queue()
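
For reference, a minimal standalone sketch of the sample-to-quantile-band arithmetic that run_dashboard() relies on above. It assumes predict() yields an array of shape (num_samples, horizon); the pipeline call is replaced here by hypothetical random draws, so only the NumPy quantile, coverage, and band-width steps mirror the app code.

import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for pipe.predict(): 300 sample paths over a 30-step horizon.
samples = rng.normal(loc=10.0, scale=2.0, size=(300, 30)).astype(np.float32)

q_low, q_high = 0.10, 0.90
median = np.quantile(samples, 0.50, axis=0)   # point forecast per horizon step
low = np.quantile(samples, q_low, axis=0)     # lower band edge
high = np.quantile(samples, q_high, axis=0)   # upper band edge

# Coverage of a hypothetical holdout against the band, as in coverage()/interval_width()
y_true = rng.normal(loc=10.0, scale=2.0, size=30).astype(np.float32)
inside = (y_true >= low) & (y_true <= high)
print(f"coverage: {inside.mean() * 100.0:.1f}%  avg band width: {float(np.mean(high - low)):.3f}")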