bosh94 commited on
Commit
de1701c
·
verified ·
1 Parent(s): 4a4dcbd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -437
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
  import time
3
  import inspect
4
- from dataclasses import dataclass
5
- from typing import Optional, Tuple, Any, Dict, List
6
 
7
  import numpy as np
8
  import pandas as pd
@@ -13,17 +12,14 @@ import plotly.graph_objects as go
13
  from chronos import Chronos2Pipeline
14
 
15
 
16
- # =========================
17
- # Config
18
- # =========================
19
  MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
20
  DATA_DIR = "data"
21
  OUT_DIR = "/tmp"
22
 
23
 
24
- # =========================
25
- # Data helpers
26
- # =========================
27
  def available_test_csv() -> List[str]:
28
  if not os.path.isdir(DATA_DIR):
29
  return []
@@ -31,42 +27,25 @@ def available_test_csv() -> List[str]:
31
 
32
 
33
  def pick_device(ui_choice: str) -> str:
34
- if (ui_choice or "").startswith("cuda") and torch.cuda.is_available():
35
- return "cuda"
36
- return "cpu"
37
-
38
-
39
- def make_sample_series(
40
- n: int,
41
- seed: int,
42
- trend: float,
43
- season_period: int,
44
- season_amp: float,
45
- noise: float,
46
- ) -> np.ndarray:
47
  rng = np.random.default_rng(int(seed))
48
  t = np.arange(int(n), dtype=np.float32)
49
- y = (
50
- float(trend) * t
51
- + float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period)))
52
- + rng.normal(0.0, float(noise), size=int(n))
53
- ).astype(np.float32)
54
  if float(np.min(y)) < 0:
55
- y = y - float(np.min(y))
56
  return y
57
 
58
 
59
- def load_series_from_csv(csv_path: str, column: Optional[str]) -> Tuple[np.ndarray, str, pd.DataFrame]:
60
  df = pd.read_csv(csv_path)
61
- if df.shape[1] == 0:
62
- raise ValueError("CSV vuoto o non leggibile.")
63
-
64
  col = (column or "").strip()
65
  if not col:
66
- # numeric columns first
67
  numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
68
  if not numeric_cols:
69
- # try coercion (strings -> numbers)
70
  for c in df.columns:
71
  coerced = pd.to_numeric(df[c], errors="coerce")
72
  if coerced.notna().sum() > 0:
@@ -74,62 +53,34 @@ def load_series_from_csv(csv_path: str, column: Optional[str]) -> Tuple[np.ndarr
74
  if not numeric_cols:
75
  raise ValueError("Non trovo colonne numeriche nel CSV.")
76
  col = numeric_cols[0]
77
-
78
  if col not in df.columns:
79
  raise ValueError(f"Colonna '{col}' non trovata. Disponibili: {list(df.columns)}")
80
-
81
  y = pd.to_numeric(df[col], errors="coerce").dropna().astype(np.float32).to_numpy()
82
  if len(y) < 10:
83
- raise ValueError("Serie troppo corta (minimo consigliato: 10 punti).")
84
- return y, col, df
85
-
86
-
87
- # =========================
88
- # Metrics
89
- # =========================
90
- def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
91
- return float(np.mean(np.abs(y_true - y_pred)))
92
-
93
-
94
- def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
95
- return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))
96
-
97
-
98
- def mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
99
- denom = np.maximum(1e-8, np.abs(y_true))
100
- return float(np.mean(np.abs((y_true - y_pred) / denom)) * 100.0)
101
-
102
-
103
- def coverage(y_true: np.ndarray, low: np.ndarray, high: np.ndarray) -> float:
104
- return float(np.mean((y_true >= low) & (y_true <= high)) * 100.0)
105
-
106
 
107
- def avg_width(low: np.ndarray, high: np.ndarray) -> float:
108
- return float(np.mean(high - low))
109
 
110
-
111
- # =========================
112
  # Model cache
113
- # =========================
114
  _PIPE = None
115
- _PIPE_META = {"model_id": None, "device": None}
116
 
117
 
118
  def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
119
- global _PIPE, _PIPE_META
120
  model_id = (model_id or MODEL_ID_DEFAULT).strip()
121
- device = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"
122
-
123
- if _PIPE is None or _PIPE_META["model_id"] != model_id or _PIPE_META["device"] != device:
124
  _PIPE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
125
- _PIPE_META = {"model_id": model_id, "device": device}
126
-
127
  return _PIPE
128
 
129
 
130
- # =========================
131
- # Chronos-2 predict (BULLETPROOF)
132
- # =========================
133
  def _to_numpy(x: Any) -> np.ndarray:
134
  if isinstance(x, np.ndarray):
135
  return x
@@ -139,126 +90,105 @@ def _to_numpy(x: Any) -> np.ndarray:
139
 
140
 
141
  def _extract_samples(raw: Any) -> np.ndarray:
142
- """
143
- Normalizza l’output in np.ndarray.
144
- Possibili output visti in librerie “young”:
145
- - list[list[float]] (samples x horizon)
146
- - list[float] (horizon) -> 1 sample
147
- - np.ndarray / torch.Tensor (horizon) o (samples, horizon)
148
- - dict con chiavi tipo 'samples', 'predictions'
149
- """
150
  if isinstance(raw, dict):
151
  for k in ["samples", "predictions", "prediction", "output"]:
152
  if k in raw:
153
  return _to_numpy(raw[k])
154
- # fallback: prova primo valore
155
  if len(raw) > 0:
156
  return _to_numpy(next(iter(raw.values())))
157
  return np.asarray([], dtype=np.float32)
158
-
159
  return _to_numpy(raw)
160
 
161
 
162
- def chronos2_predict_samples(
163
- pipe: Chronos2Pipeline,
164
- y: np.ndarray,
165
- prediction_length: int,
166
- num_samples_ui: int,
167
- ) -> np.ndarray:
168
  """
169
- Gestisce:
170
- - inputs obbligatorio (posizionale o keyword)
171
- - prediction_length può chiamarsi prediction_length / horizon / steps
172
- - numero campioni può chiamarsi n_samples / num_return_sequences / ...
173
- - oppure non esiste: allora torna 1 sample e noi facciamo broadcast
174
  """
175
- ctx = y.tolist()
176
-
177
  sig = inspect.signature(pipe.predict)
178
  params = sig.parameters
179
 
180
- # 1) name for horizon
 
 
 
181
  horizon_kw = None
182
  for cand in ["prediction_length", "horizon", "steps", "n_steps", "pred_len"]:
183
  if cand in params:
184
  horizon_kw = cand
185
  break
186
 
187
- # 2) name for samples count
188
  sample_kw = None
189
- for cand in ["n_samples", "num_samples", "num_return_sequences", "samples", "n"]:
190
  if cand in params:
191
  sample_kw = cand
192
  break
193
 
194
- # build kwargs
195
  kwargs: Dict[str, Any] = {}
196
- if horizon_kw is not None:
197
- kwargs[horizon_kw] = int(prediction_length)
198
-
199
- # include sample kw only if supported
200
- if sample_kw is not None:
201
- kwargs[sample_kw] = int(num_samples_ui)
202
-
203
- # inputs handling
204
- if "inputs" in params:
205
- raw = pipe.predict(inputs=ctx, **kwargs)
206
  else:
207
- # some builds may accept positional only
208
- raw = pipe.predict(ctx, **kwargs)
209
 
 
 
 
 
 
210
  arr = _extract_samples(raw).astype(np.float32, copy=False)
211
 
212
- # normalize shape -> (samples, horizon)
213
- if arr.ndim == 0:
214
- # degenerate
215
- arr = arr.reshape(1, 1)
216
- elif arr.ndim == 1:
217
- # (horizon,) -> (1, horizon)
218
  arr = arr[None, :]
219
- elif arr.ndim >= 3:
220
- # squeeze extras if any
221
- arr = np.squeeze(arr)
 
 
222
  if arr.ndim == 1:
223
  arr = arr[None, :]
224
 
225
- # ensure horizon matches if possible (some APIs might return longer/shorter)
226
- if arr.shape[1] != int(prediction_length):
227
- # best-effort: trim or pad with last value
228
- h = int(prediction_length)
229
- if arr.shape[1] > h:
230
- arr = arr[:, :h]
231
  else:
232
- pad = h - arr.shape[1]
233
- last = arr[:, -1:]
234
- arr = np.concatenate([arr, np.repeat(last, pad, axis=1)], axis=1)
235
-
236
- # if API didn’t support sample count, we may only have 1 sample: replicate to compute quantiles smoothly
237
- if arr.shape[0] == 1 and num_samples_ui > 1:
238
- arr = np.repeat(arr, repeats=int(num_samples_ui), axis=0)
239
 
240
- return arr
 
 
 
241
 
242
 
243
- # =========================
244
  # Plotly
245
- # =========================
246
- def plot_forecast(y, median, low, high, title, q_low, q_high) -> go.Figure:
247
  t_hist = np.arange(len(y))
248
  t_fcst = np.arange(len(y), len(y) + len(median))
249
 
250
  fig = go.Figure()
251
  fig.add_trace(go.Scatter(x=t_hist, y=y, mode="lines", name="History"))
252
-
253
- fig.add_trace(go.Scatter(x=t_fcst, y=high, mode="lines", line=dict(width=0),
254
- showlegend=False, hoverinfo="skip"))
255
- fig.add_trace(go.Scatter(
256
- x=t_fcst, y=low, mode="lines", fill="tonexty",
257
- line=dict(width=0), name=f"Band [{q_low:.2f}, {q_high:.2f}]"
258
- ))
259
  fig.add_trace(go.Scatter(x=t_fcst, y=median, mode="lines", name="Forecast (median)"))
260
  fig.add_vline(x=len(y) - 1, line_width=1, line_dash="dash", opacity=0.6)
261
 
 
 
 
 
 
 
 
 
262
  fig.update_layout(
263
  title=title,
264
  hovermode="x unified",
@@ -270,46 +200,23 @@ def plot_forecast(y, median, low, high, title, q_low, q_high) -> go.Figure:
270
  return fig
271
 
272
 
273
- def plot_backtest(y_train, y_true, pred, low, high, q_low, q_high) -> go.Figure:
274
- t_train = np.arange(len(y_train))
275
- t_test = np.arange(len(y_train), len(y_train) + len(y_true))
 
 
 
 
 
 
 
276
 
277
- fig = go.Figure()
278
- fig.add_trace(go.Scatter(x=t_train, y=y_train, mode="lines", name="Train"))
279
- fig.add_trace(go.Scatter(x=t_test, y=y_true, mode="lines", name="True (holdout)"))
280
-
281
- fig.add_trace(go.Scatter(x=t_test, y=high, mode="lines", line=dict(width=0),
282
- showlegend=False, hoverinfo="skip"))
283
- fig.add_trace(go.Scatter(
284
- x=t_test, y=low, mode="lines", fill="tonexty",
285
- line=dict(width=0), name=f"Band [{q_low:.2f}, {q_high:.2f}]"
286
- ))
287
- fig.add_trace(go.Scatter(x=t_test, y=pred, mode="lines", name="Pred (median)"))
288
-
289
- fig.add_vline(x=len(y_train) - 1, line_width=1, line_dash="dash", opacity=0.6)
290
- fig.update_layout(
291
- title="Backtest (holdout) — interactive",
292
- hovermode="x unified",
293
- margin=dict(l=10, r=10, t=55, b=10),
294
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
295
- xaxis_title="t",
296
- yaxis_title="value",
297
- )
298
- return fig
299
 
300
 
301
- # =========================
302
- # Natural language explanation
303
- # =========================
304
- def explain_output(
305
- y: np.ndarray,
306
- median: np.ndarray,
307
- low: np.ndarray,
308
- high: np.ndarray,
309
- q_low: float,
310
- q_high: float,
311
- backtest: Optional[Dict[str, float]],
312
- ) -> str:
313
  horizon = len(median)
314
  base = float(np.mean(y))
315
  delta = float(median[-1] - median[0])
@@ -322,280 +229,131 @@ def explain_output(
322
  else:
323
  trend_txt = "in calo"
324
 
325
- w = float(np.mean(high - low))
326
- rel_w = (w / max(1e-6, float(np.mean(median)))) * 100.0
327
- if rel_w < 10:
328
- uncert_txt = "bassa"
329
- elif rel_w < 25:
330
- uncert_txt = "moderata"
331
- else:
332
- uncert_txt = "alta"
333
-
334
  txt = f"""
335
- ### 🧠 Spiegazione (linguaggio naturale)
336
-
337
- **Cosa sta dicendo il modello:** nei prossimi **{horizon} step** la serie è **{trend_txt}** (variazione mediana complessiva ≈ **{pct:+.1f}%** rispetto al livello medio storico).
338
 
339
- - **Mediana all’ultimo step:** **{median[-1]:.2f}**
340
- - **Intervallo [{q_low:.0%}–{q_high:.0%}] all’ultimo step:** **[{low[-1]:.2f} – {high[-1]:.2f}]**
341
- - **Incertezza:** **{uncert_txt}** (larghezza media banda ≈ **{w:.2f}**, ~**{rel_w:.1f}%** della mediana)
342
-
343
- **Come usarlo:** usa la **mediana** come previsione “baseline”; usa il **quantile alto** per scenari prudenziali (es. scorte/capacità) e il **quantile basso** per scenari conservativi (es. budget).
344
  """
 
 
 
 
345
 
346
- if backtest:
347
- target_cov = (q_high - q_low) * 100.0
348
- cov = backtest["coverage"]
349
- calib = "buona" if abs(cov - target_cov) <= 10 else "migliorabile"
350
- txt += f"""
351
-
352
- ### 🧪 Affidabilità (backtest)
353
-
354
- Sul tratto holdout:
355
- - **MAE:** {backtest["mae"]:.3f}
356
- - **RMSE:** {backtest["rmse"]:.3f}
357
- - **MAPE:** {backtest["mape"]:.2f}%
358
- - **Coverage:** {cov:.1f}% (target atteso ≈ {target_cov:.1f}%)
359
-
360
- Interpretazione: la banda di incertezza ha una calibrazione **{calib}** sul passato recente.
361
- """
362
  return txt
363
 
364
 
365
- # =========================
366
- # KPI HTML
367
- # =========================
368
- def kpi_card(label: str, value: str, hint: str = "") -> str:
369
- hint_html = f"<div style='opacity:.75;font-size:12px;margin-top:6px;'>{hint}</div>" if hint else ""
370
- return f"""
371
- <div style="border:1px solid rgba(255,255,255,.12); border-radius:16px; padding:14px 16px;
372
- background: rgba(255,255,255,.04); backdrop-filter: blur(6px);">
373
- <div style="font-size:12px;opacity:.8;">{label}</div>
374
- <div style="font-size:22px;font-weight:700;margin-top:4px;">{value}</div>
375
- {hint_html}
376
- </div>
377
- """
378
-
379
-
380
- def kpi_grid(cards: List[str]) -> str:
381
- return f"<div style='display:grid; grid-template-columns: repeat(6, minmax(0, 1fr)); gap:12px;'>{''.join(cards)}</div>"
382
-
383
-
384
- @dataclass
385
- class Outputs:
386
- kpis_html: str
387
- explanation_md: str
388
- forecast_fig: go.Figure
389
- backtest_fig: go.Figure
390
- forecast_table: pd.DataFrame
391
- backtest_table: pd.DataFrame
392
- forecast_csv_path: str
393
- backtest_csv_path: Optional[str]
394
- info: dict
395
-
396
-
397
- # =========================
398
- # Core run
399
- # =========================
400
- def run_dashboard(
401
- input_mode: str,
402
- test_csv_name: str,
403
- upload_csv,
404
- csv_column: str,
405
-
406
- n: int,
407
- seed: int,
408
- trend: float,
409
- season_period: int,
410
- season_amp: float,
411
- noise: float,
412
-
413
- prediction_length: int,
414
- num_samples: int,
415
- q_low: float,
416
- q_high: float,
417
-
418
- do_backtest: bool,
419
- holdout: int,
420
-
421
- device_ui: str,
422
- model_id: str,
423
- ) -> Outputs:
424
  if q_low >= q_high:
425
  raise gr.Error("Quantile low deve essere < quantile high.")
426
 
427
  device = pick_device(device_ui)
 
428
 
429
- # Load series
430
  if input_mode == "Test CSV":
431
  if not test_csv_name:
432
  raise gr.Error("Seleziona un Test CSV.")
433
  path = os.path.join(DATA_DIR, test_csv_name)
434
- if not os.path.exists(path):
435
- raise gr.Error(f"File non trovato: {path}")
436
- y, used_col, _ = load_series_from_csv(path, csv_column)
437
  source = f"Test CSV: {test_csv_name} • col={used_col}"
438
-
439
  elif input_mode == "Upload CSV":
440
  if upload_csv is None:
441
  raise gr.Error("Carica un CSV.")
442
- y, used_col, _ = load_series_from_csv(upload_csv.name, csv_column)
443
  source = f"Upload CSV • col={used_col}"
444
-
445
  else:
446
  y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
447
  source = "Sample series"
448
 
449
- if do_backtest and holdout >= len(y):
450
- raise gr.Error("Holdout deve essere più piccolo della lunghezza dello storico.")
451
-
452
  t0 = time.time()
453
- pipe = get_pipeline(model_id, device)
 
454
 
455
- # Forecast (samples x horizon)
456
- samples = chronos2_predict_samples(pipe, y, int(prediction_length), int(num_samples))
457
  median = np.quantile(samples, 0.50, axis=0)
458
- low = np.quantile(samples, float(q_low), axis=0)
459
- high = np.quantile(samples, float(q_high), axis=0)
 
 
 
 
 
460
 
461
- # Tables & export
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
  t_fcst = np.arange(len(y), len(y) + int(prediction_length))
463
- forecast_df = pd.DataFrame({
464
  "t": t_fcst,
465
  "median": median,
466
- f"q{q_low:.2f}": low,
467
- f"q{q_high:.2f}": high,
468
  })
469
- forecast_csv_path = os.path.join(OUT_DIR, "chronos2_forecast.csv")
470
- forecast_df.to_csv(forecast_csv_path, index=False)
 
471
 
472
- forecast_fig = plot_forecast(y, median, low, high, f"Forecast — {source}", q_low, q_high)
 
473
 
474
- # Backtest optional
475
- empty_backtest_fig = go.Figure().update_layout(
476
- title="Backtest disabled",
477
- margin=dict(l=10, r=10, t=55, b=10),
478
- )
479
- backtest_fig = empty_backtest_fig
480
- backtest_df = pd.DataFrame()
481
- backtest_csv_path = None
482
- backtest_metrics = None
483
-
484
- # KPIs base
485
- elapsed = time.time() - t0
486
- cards = [
487
- kpi_card("Device", device.upper(), f"cuda_available={torch.cuda.is_available()}"),
488
- kpi_card("Model", (model_id or MODEL_ID_DEFAULT), "Chronos-2"),
489
- kpi_card("Latency", f"{elapsed:.2f}s", "cached after first run"),
490
- kpi_card("Samples (req)", f"{int(num_samples)}", "requested"),
491
- kpi_card("Interval", f"[{q_low:.2f}, {q_high:.2f}]", "uncertainty band"),
492
- kpi_card("Band width", f"{avg_width(low, high):.3f}", "forecast band"),
493
- ]
494
-
495
- if do_backtest:
496
- y_train = y[:-int(holdout)]
497
- y_true = y[-int(holdout):]
498
-
499
- bt_samples = chronos2_predict_samples(pipe, y_train, int(holdout), int(num_samples))
500
- bt_med = np.quantile(bt_samples, 0.50, axis=0)
501
- bt_low = np.quantile(bt_samples, float(q_low), axis=0)
502
- bt_high = np.quantile(bt_samples, float(q_high), axis=0)
503
-
504
- bt_mae = mae(y_true, bt_med)
505
- bt_rmse = rmse(y_true, bt_med)
506
- bt_mape = mape(y_true, bt_med)
507
- bt_cov = coverage(y_true, bt_low, bt_high)
508
- bt_w = avg_width(bt_low, bt_high)
509
-
510
- backtest_metrics = {"mae": bt_mae, "rmse": bt_rmse, "mape": bt_mape, "coverage": bt_cov}
511
-
512
- cards += [
513
- kpi_card("BT MAE", f"{bt_mae:.3f}", f"holdout={holdout}"),
514
- kpi_card("BT RMSE", f"{bt_rmse:.3f}"),
515
- kpi_card("BT MAPE", f"{bt_mape:.2f}%"),
516
- kpi_card("Coverage", f"{bt_cov:.1f}%", "inside band"),
517
- kpi_card("BT width", f"{bt_w:.3f}", "avg band"),
518
- ]
519
-
520
- backtest_fig = plot_backtest(y_train, y_true, bt_med, bt_low, bt_high, q_low, q_high)
521
-
522
- t_test = np.arange(len(y_train), len(y_train) + int(holdout))
523
- backtest_df = pd.DataFrame({
524
- "t": t_test,
525
- "true": y_true,
526
- "pred_median": bt_med,
527
- f"q{q_low:.2f}": bt_low,
528
- f"q{q_high:.2f}": bt_high,
529
- })
530
- backtest_csv_path = os.path.join(OUT_DIR, "chronos2_backtest.csv")
531
- backtest_df.to_csv(backtest_csv_path, index=False)
532
-
533
- explanation_md = explain_output(y, median, low, high, q_low, q_high, backtest_metrics)
534
 
535
  info = {
536
  "source": source,
537
  "history_points": int(len(y)),
538
  "prediction_length": int(prediction_length),
539
- "num_samples_requested": int(num_samples),
540
- "q_low": float(q_low),
541
- "q_high": float(q_high),
542
- "backtest": bool(do_backtest),
543
- "holdout": int(holdout) if do_backtest else None,
544
  "predict_signature": str(inspect.signature(pipe.predict)),
 
545
  }
546
 
547
- return Outputs(
548
- kpis_html=kpi_grid(cards),
549
- explanation_md=explanation_md,
550
- forecast_fig=forecast_fig,
551
- backtest_fig=backtest_fig,
552
- forecast_table=forecast_df,
553
- backtest_table=backtest_df,
554
- forecast_csv_path=forecast_csv_path,
555
- backtest_csv_path=backtest_csv_path,
556
- info=info,
557
- )
558
-
559
-
560
- def run_wrapped(*args):
561
- out = run_dashboard(*args)
562
- return (
563
- out.kpis_html,
564
- out.explanation_md,
565
- out.forecast_fig,
566
- out.backtest_fig,
567
- out.forecast_table,
568
- out.backtest_table,
569
- out.forecast_csv_path,
570
- out.backtest_csv_path,
571
- out.info,
572
- )
573
 
574
 
575
- # =========================
576
  # UI
577
- # =========================
578
- css = """
579
- .gradio-container { max-width: 1200px !important; }
580
- """
581
 
582
- with gr.Blocks(title="Chronos-2 • Forecast Dashboard", css=css) as demo:
583
- gr.Markdown(
584
- """
585
- # ⏱️ Chronos-2 Forecast Dashboard (Bulletproof)
586
- Plotly interattivo + KPI + backtest + export + spiegazione in linguaggio naturale.
587
- """
588
- )
589
 
590
  with gr.Row():
591
  with gr.Column(scale=1, min_width=360):
592
- gr.Markdown("## Input")
593
- input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Sorgente dati")
594
  test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV (data/)")
595
  upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
596
  csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
597
 
598
- gr.Markdown("## Sistema")
599
  device_ui = gr.Dropdown(
600
  ["cpu", "cuda (se disponibile)"],
601
  value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
@@ -611,61 +369,35 @@ Plotly interattivo + KPI + backtest + export + spiegazione in linguaggio natural
611
  season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
612
  noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")
613
 
614
- gr.Markdown("## Forecast")
615
  prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
616
- num_samples = gr.Slider(1, 800, value=300, step=25, label="Num samples (requested)")
617
  q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
618
  q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")
619
 
620
- gr.Markdown("## Backtest")
621
- do_backtest = gr.Checkbox(value=True, label="Esegui backtest holdout")
622
- holdout = gr.Slider(5, 365, value=30, step=1, label="Holdout points")
623
-
624
  run_btn = gr.Button("Run", variant="primary")
625
 
626
  with gr.Column(scale=2):
627
- gr.Markdown("## KPI")
628
  kpis = gr.HTML()
629
-
630
  with gr.Tabs():
631
  with gr.Tab("Forecast"):
632
- forecast_plot = gr.Plot(label="Forecast (interactive)")
633
- forecast_table = gr.Dataframe(label="Forecast table", interactive=False)
634
-
635
- with gr.Tab("Backtest"):
636
- backtest_plot = gr.Plot(label="Backtest (interactive)")
637
- backtest_table = gr.Dataframe(label="Backtest table", interactive=False)
638
-
639
  with gr.Tab("Spiegazione"):
640
  explanation = gr.Markdown()
641
-
642
  with gr.Tab("Export"):
643
- forecast_download = gr.File(label="Forecast CSV")
644
- backtest_download = gr.File(label="Backtest CSV")
645
-
646
- with gr.Tab("Run info"):
647
- run_info = gr.JSON(label="Info")
648
 
649
  run_btn.click(
650
- fn=run_wrapped,
651
  inputs=[
652
  input_mode, test_csv_name, upload_csv, csv_column,
653
  n, seed, trend, season_period, season_amp, noise,
654
- prediction_length, num_samples, q_low, q_high,
655
- do_backtest, holdout,
656
  device_ui, model_id,
657
  ],
658
- outputs=[
659
- kpis,
660
- explanation,
661
- forecast_plot,
662
- backtest_plot,
663
- forecast_table,
664
- backtest_table,
665
- forecast_download,
666
- backtest_download,
667
- run_info,
668
- ],
669
  )
670
 
671
  demo.queue()
 
1
  import os
2
  import time
3
  import inspect
4
+ from typing import Any, Dict, List, Optional, Tuple
 
5
 
6
  import numpy as np
7
  import pandas as pd
 
12
  from chronos import Chronos2Pipeline
13
 
14
 
 
 
 
15
  MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
16
  DATA_DIR = "data"
17
  OUT_DIR = "/tmp"
18
 
19
 
20
+ # -------------------------
21
+ # Data
22
+ # -------------------------
23
  def available_test_csv() -> List[str]:
24
  if not os.path.isdir(DATA_DIR):
25
  return []
 
27
 
28
 
29
  def pick_device(ui_choice: str) -> str:
30
+ return "cuda" if (ui_choice or "").startswith("cuda") and torch.cuda.is_available() else "cpu"
31
+
32
+
33
+ def make_sample_series(n: int, seed: int, trend: float, season_period: int, season_amp: float, noise: float) -> np.ndarray:
 
 
 
 
 
 
 
 
 
34
  rng = np.random.default_rng(int(seed))
35
  t = np.arange(int(n), dtype=np.float32)
36
+ y = (trend * t + season_amp * np.sin(2 * np.pi * t / max(1, int(season_period))) + rng.normal(0, noise, size=int(n))).astype(np.float32)
 
 
 
 
37
  if float(np.min(y)) < 0:
38
+ y -= float(np.min(y))
39
  return y
40
 
41
 
42
+ def load_series_from_csv(csv_path: str, column: Optional[str]) -> Tuple[np.ndarray, str]:
43
  df = pd.read_csv(csv_path)
 
 
 
44
  col = (column or "").strip()
45
  if not col:
 
46
  numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
47
  if not numeric_cols:
48
+ # try coercion
49
  for c in df.columns:
50
  coerced = pd.to_numeric(df[c], errors="coerce")
51
  if coerced.notna().sum() > 0:
 
53
  if not numeric_cols:
54
  raise ValueError("Non trovo colonne numeriche nel CSV.")
55
  col = numeric_cols[0]
 
56
  if col not in df.columns:
57
  raise ValueError(f"Colonna '{col}' non trovata. Disponibili: {list(df.columns)}")
 
58
  y = pd.to_numeric(df[col], errors="coerce").dropna().astype(np.float32).to_numpy()
59
  if len(y) < 10:
60
+ raise ValueError("Serie troppo corta.")
61
+ return y, col
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
 
63
 
64
+ # -------------------------
 
65
  # Model cache
66
+ # -------------------------
67
  _PIPE = None
68
+ _META = {"model_id": None, "device": None}
69
 
70
 
71
  def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
72
+ global _PIPE, _META
73
  model_id = (model_id or MODEL_ID_DEFAULT).strip()
74
+ device = "cuda" if device == "cuda" and torch.cuda.is_available() else "cpu"
75
+ if _PIPE is None or _META["model_id"] != model_id or _META["device"] != device:
 
76
  _PIPE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
77
+ _META = {"model_id": model_id, "device": device}
 
78
  return _PIPE
79
 
80
 
81
+ # -------------------------
82
+ # Predict (STABLE)
83
+ # -------------------------
84
  def _to_numpy(x: Any) -> np.ndarray:
85
  if isinstance(x, np.ndarray):
86
  return x
 
90
 
91
 
92
  def _extract_samples(raw: Any) -> np.ndarray:
 
 
 
 
 
 
 
 
93
  if isinstance(raw, dict):
94
  for k in ["samples", "predictions", "prediction", "output"]:
95
  if k in raw:
96
  return _to_numpy(raw[k])
 
97
  if len(raw) > 0:
98
  return _to_numpy(next(iter(raw.values())))
99
  return np.asarray([], dtype=np.float32)
 
100
  return _to_numpy(raw)
101
 
102
 
103
+ def chronos2_predict(pipe: Chronos2Pipeline, y: np.ndarray, horizon: int, requested_samples: int) -> Tuple[np.ndarray, bool, str]:
 
 
 
 
 
104
  """
105
+ Returns:
106
+ samples: (S, H)
107
+ multi: whether S>1 is real (not replicated)
108
+ note: debug note
 
109
  """
 
 
110
  sig = inspect.signature(pipe.predict)
111
  params = sig.parameters
112
 
113
+ # input format: ALWAYS batch = [series]
114
+ inputs = [y.tolist()]
115
+
116
+ # kw for horizon
117
  horizon_kw = None
118
  for cand in ["prediction_length", "horizon", "steps", "n_steps", "pred_len"]:
119
  if cand in params:
120
  horizon_kw = cand
121
  break
122
 
123
+ # kw for samples count (many versions don't have it!)
124
  sample_kw = None
125
+ for cand in ["n_samples", "num_return_sequences", "num_samples"]:
126
  if cand in params:
127
  sample_kw = cand
128
  break
129
 
 
130
  kwargs: Dict[str, Any] = {}
131
+ if horizon_kw:
132
+ kwargs[horizon_kw] = int(horizon)
 
 
 
 
 
 
 
 
133
  else:
134
+ # worst case: try positional horizon if supported (rare)
135
+ kwargs["prediction_length"] = int(horizon)
136
 
137
+ if sample_kw:
138
+ kwargs[sample_kw] = int(requested_samples)
139
+
140
+ # call
141
+ raw = pipe.predict(inputs=inputs, **kwargs) if "inputs" in params else pipe.predict(inputs, **kwargs)
142
  arr = _extract_samples(raw).astype(np.float32, copy=False)
143
 
144
+ # normalize shape -> (S,H)
145
+ arr = np.squeeze(arr)
146
+ if arr.ndim == 1:
147
+ # could be (H,) or (S,) - assume horizon if length == H
 
 
148
  arr = arr[None, :]
149
+
150
+ # Sometimes output is (B,S,H) or (B,H). If batch dim exists, take first
151
+ if arr.ndim == 3:
152
+ # assume (B,S,H) or (S,B,H); safest: pick first on axis=0
153
+ arr = arr[0]
154
  if arr.ndim == 1:
155
  arr = arr[None, :]
156
 
157
+ # ensure horizon length
158
+ if arr.shape[-1] != horizon:
159
+ if arr.shape[-1] > horizon:
160
+ arr = arr[..., :horizon]
 
 
161
  else:
162
+ pad = horizon - arr.shape[-1]
163
+ last = arr[..., -1:]
164
+ arr = np.concatenate([arr, np.repeat(last, pad, axis=-1)], axis=-1)
 
 
 
 
165
 
166
+ # If we got only 1 sample, we can still plot median but band is not meaningful
167
+ real_multi = arr.shape[0] > 1
168
+ note = f"predict_signature={sig} | used_horizon_kw={horizon_kw} | used_sample_kw={sample_kw} | got_shape={tuple(arr.shape)}"
169
+ return arr, real_multi, note
170
 
171
 
172
+ # -------------------------
173
  # Plotly
174
+ # -------------------------
175
+ def plot_forecast(y, median, low, high, title, show_band: bool, band_label: str) -> go.Figure:
176
  t_hist = np.arange(len(y))
177
  t_fcst = np.arange(len(y), len(y) + len(median))
178
 
179
  fig = go.Figure()
180
  fig.add_trace(go.Scatter(x=t_hist, y=y, mode="lines", name="History"))
 
 
 
 
 
 
 
181
  fig.add_trace(go.Scatter(x=t_fcst, y=median, mode="lines", name="Forecast (median)"))
182
  fig.add_vline(x=len(y) - 1, line_width=1, line_dash="dash", opacity=0.6)
183
 
184
+ if show_band:
185
+ fig.add_trace(go.Scatter(x=t_fcst, y=high, mode="lines", line=dict(width=0),
186
+ showlegend=False, hoverinfo="skip"))
187
+ fig.add_trace(go.Scatter(
188
+ x=t_fcst, y=low, mode="lines", fill="tonexty",
189
+ line=dict(width=0), name=band_label
190
+ ))
191
+
192
  fig.update_layout(
193
  title=title,
194
  hovermode="x unified",
 
200
  return fig
201
 
202
 
203
+ def kpi_card(label: str, value: str, hint: str = "") -> str:
204
+ hint_html = f"<div style='opacity:.75;font-size:12px;margin-top:6px;'>{hint}</div>" if hint else ""
205
+ return f"""
206
+ <div style="border:1px solid rgba(255,255,255,.12); border-radius:16px; padding:14px 16px;
207
+ background: rgba(255,255,255,.04);">
208
+ <div style="font-size:12px;opacity:.8;">{label}</div>
209
+ <div style="font-size:22px;font-weight:700;margin-top:4px;">{value}</div>
210
+ {hint_html}
211
+ </div>
212
+ """
213
 
214
+
215
+ def kpi_grid(cards: List[str]) -> str:
216
+ return f"<div style='display:grid; grid-template-columns: repeat(6, minmax(0, 1fr)); gap:12px;'>{''.join(cards)}</div>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
 
219
+ def explain(y, median, low, high, band_enabled: bool, q_low: float, q_high: float, extra: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
220
  horizon = len(median)
221
  base = float(np.mean(y))
222
  delta = float(median[-1] - median[0])
 
229
  else:
230
  trend_txt = "in calo"
231
 
 
 
 
 
 
 
 
 
 
232
  txt = f"""
233
+ ### 🧠 Spiegazione
 
 
234
 
235
+ Nei prossimi **{horizon} step** la previsione mediana è **{trend_txt}** (variazione ≈ **{pct:+.1f}%** rispetto al livello medio storico).
236
+ - **Ultimo valore mediano previsto:** **{median[-1]:.2f}**
 
 
 
237
  """
238
+ if band_enabled:
239
+ txt += f"- **Banda [{q_low:.0%}–{q_high:.0%}] (ultimo step):** **[{low[-1]:.2f} – {high[-1]:.2f}]**\n"
240
+ else:
241
+ txt += "- **Banda di incertezza:** disattivata (questa versione di Chronos2 non restituisce campioni multipli con i parametri disponibili).\n"
242
 
243
+ txt += f"\n<details><summary>Debug</summary>\n\n`{extra}`\n\n</details>\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  return txt
245
 
246
 
247
+ # -------------------------
248
+ # Run
249
+ # -------------------------
250
+ def run_all(
251
+ input_mode, test_csv_name, upload_csv, csv_column,
252
+ n, seed, trend, season_period, season_amp, noise,
253
+ prediction_length, requested_samples, q_low, q_high,
254
+ device_ui, model_id,
255
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  if q_low >= q_high:
257
  raise gr.Error("Quantile low deve essere < quantile high.")
258
 
259
  device = pick_device(device_ui)
260
+ pipe = get_pipeline(model_id, device)
261
 
262
+ # data
263
  if input_mode == "Test CSV":
264
  if not test_csv_name:
265
  raise gr.Error("Seleziona un Test CSV.")
266
  path = os.path.join(DATA_DIR, test_csv_name)
267
+ y, used_col = load_series_from_csv(path, csv_column)
 
 
268
  source = f"Test CSV: {test_csv_name} • col={used_col}"
 
269
  elif input_mode == "Upload CSV":
270
  if upload_csv is None:
271
  raise gr.Error("Carica un CSV.")
272
+ y, used_col = load_series_from_csv(upload_csv.name, csv_column)
273
  source = f"Upload CSV • col={used_col}"
 
274
  else:
275
  y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
276
  source = "Sample series"
277
 
 
 
 
278
  t0 = time.time()
279
+ samples, real_multi, note = chronos2_predict(pipe, y, int(prediction_length), int(requested_samples))
280
+ latency = time.time() - t0
281
 
 
 
282
  median = np.quantile(samples, 0.50, axis=0)
283
+ band_enabled = real_multi and samples.shape[0] > 2
284
+ if band_enabled:
285
+ low = np.quantile(samples, float(q_low), axis=0)
286
+ high = np.quantile(samples, float(q_high), axis=0)
287
+ else:
288
+ low = median.copy()
289
+ high = median.copy()
290
 
291
+ # KPI
292
+ cards = [
293
+ kpi_card("Device", device.upper(), f"cuda_available={torch.cuda.is_available()}"),
294
+ kpi_card("Latency", f"{latency:.2f}s", "predict()"),
295
+ kpi_card("Samples", str(samples.shape[0]), "returned by model"),
296
+ kpi_card("Band", "ON" if band_enabled else "OFF", "needs multi-samples"),
297
+ kpi_card("Horizon", str(prediction_length)),
298
+ kpi_card("Model", (model_id or MODEL_ID_DEFAULT)),
299
+ ]
300
+ kpis_html = kpi_grid(cards)
301
+
302
+ # Plot
303
+ fig = plot_forecast(
304
+ y=y,
305
+ median=median,
306
+ low=low,
307
+ high=high,
308
+ title=f"Forecast — {source}",
309
+ show_band=band_enabled,
310
+ band_label=f"Band [{q_low:.2f}, {q_high:.2f}]",
311
+ )
312
+
313
+ # Table + export
314
  t_fcst = np.arange(len(y), len(y) + int(prediction_length))
315
+ out_df = pd.DataFrame({
316
  "t": t_fcst,
317
  "median": median,
 
 
318
  })
319
+ if band_enabled:
320
+ out_df[f"q{q_low:.2f}"] = low
321
+ out_df[f"q{q_high:.2f}"] = high
322
 
323
+ out_path = os.path.join(OUT_DIR, "chronos2_forecast.csv")
324
+ out_df.to_csv(out_path, index=False)
325
 
326
+ explanation_md = explain(y, median, low, high, band_enabled, q_low, q_high, note)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
  info = {
329
  "source": source,
330
  "history_points": int(len(y)),
331
  "prediction_length": int(prediction_length),
332
+ "requested_samples": int(requested_samples),
333
+ "returned_samples": int(samples.shape[0]),
334
+ "band_enabled": bool(band_enabled),
 
 
335
  "predict_signature": str(inspect.signature(pipe.predict)),
336
+ "debug_note": note,
337
  }
338
 
339
+ return kpis_html, explanation_md, fig, out_df, out_path, info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
 
342
+ # -------------------------
343
  # UI
344
+ # -------------------------
345
+ css = """.gradio-container { max-width: 1200px !important; }"""
 
 
346
 
347
+ with gr.Blocks(title="Chronos-2 • Pro Dashboard (Stable)", css=css) as demo:
348
+ gr.Markdown("# ⏱️ Chronos-2 Forecast Dashboard — Stable Edition")
 
 
 
 
 
349
 
350
  with gr.Row():
351
  with gr.Column(scale=1, min_width=360):
352
+ input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Input")
 
353
  test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV (data/)")
354
  upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
355
  csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
356
 
 
357
  device_ui = gr.Dropdown(
358
  ["cpu", "cuda (se disponibile)"],
359
  value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
 
369
  season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
370
  noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")
371
 
 
372
  prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
373
+ requested_samples = gr.Slider(1, 800, value=200, step=25, label="Requested samples (best effort)")
374
  q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
375
  q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")
376
 
 
 
 
 
377
  run_btn = gr.Button("Run", variant="primary")
378
 
379
  with gr.Column(scale=2):
 
380
  kpis = gr.HTML()
 
381
  with gr.Tabs():
382
  with gr.Tab("Forecast"):
383
+ forecast_plot = gr.Plot()
384
+ forecast_table = gr.Dataframe(interactive=False)
 
 
 
 
 
385
  with gr.Tab("Spiegazione"):
386
  explanation = gr.Markdown()
 
387
  with gr.Tab("Export"):
388
+ download = gr.File()
389
+ with gr.Tab("Info"):
390
+ info = gr.JSON()
 
 
391
 
392
  run_btn.click(
393
+ fn=run_all,
394
  inputs=[
395
  input_mode, test_csv_name, upload_csv, csv_column,
396
  n, seed, trend, season_period, season_amp, noise,
397
+ prediction_length, requested_samples, q_low, q_high,
 
398
  device_ui, model_id,
399
  ],
400
+ outputs=[kpis, explanation, forecast_plot, forecast_table, download, info],
 
 
 
 
 
 
 
 
 
 
401
  )
402
 
403
  demo.queue()