Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,9 +1,14 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
import gradio as gr
|
| 5 |
-
import matplotlib.pyplot as plt
|
| 6 |
import torch
|
|
|
|
| 7 |
|
| 8 |
from chronos import Chronos2Pipeline
|
| 9 |
|
|
@@ -13,322 +18,604 @@ from chronos import Chronos2Pipeline
|
|
| 13 |
# =========================
|
| 14 |
MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
|
| 15 |
DATA_DIR = "data"
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
# =========================
|
| 19 |
-
#
|
| 20 |
# =========================
|
| 21 |
-
def available_test_csv():
|
| 22 |
if not os.path.isdir(DATA_DIR):
|
| 23 |
return []
|
| 24 |
-
return sorted(f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv"))
|
| 25 |
|
| 26 |
|
| 27 |
def pick_device(ui_choice: str) -> str:
|
| 28 |
-
if
|
| 29 |
return "cuda"
|
| 30 |
return "cpu"
|
| 31 |
|
| 32 |
|
| 33 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
rng = np.random.default_rng(int(seed))
|
| 35 |
-
t = np.arange(int(n))
|
| 36 |
y = (
|
| 37 |
float(trend) * t
|
| 38 |
+ float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period)))
|
| 39 |
-
+ rng.normal(0.0, float(noise), size=
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
return y.astype(np.float32)
|
| 46 |
|
| 47 |
|
| 48 |
-
def load_series_from_csv(
|
| 49 |
-
df = pd.read_csv(
|
| 50 |
if df.shape[1] == 0:
|
| 51 |
raise ValueError("CSV vuoto o non leggibile.")
|
| 52 |
|
| 53 |
col = (column or "").strip()
|
| 54 |
-
if col
|
| 55 |
-
# try native numeric dtypes first
|
| 56 |
numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
|
| 57 |
-
# fallback: try coercion
|
| 58 |
if not numeric_cols:
|
|
|
|
| 59 |
for c in df.columns:
|
| 60 |
coerced = pd.to_numeric(df[c], errors="coerce")
|
| 61 |
-
if coerced.notna().sum()
|
| 62 |
numeric_cols.append(c)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
col = numeric_cols[0]
|
| 66 |
|
| 67 |
if col not in df.columns:
|
| 68 |
-
raise ValueError(f"Colonna '{col}' non trovata.
|
| 69 |
|
| 70 |
-
y = pd.to_numeric(df[col], errors="coerce").dropna().to_numpy()
|
| 71 |
if len(y) < 10:
|
| 72 |
-
raise ValueError("Serie troppo corta (minimo
|
| 73 |
|
| 74 |
-
return y
|
| 75 |
|
| 76 |
|
| 77 |
# =========================
|
| 78 |
-
#
|
| 79 |
# =========================
|
| 80 |
_PIPELINE = None
|
| 81 |
-
_PIPELINE_META = {}
|
| 82 |
|
| 83 |
|
| 84 |
-
def get_pipeline(model_id: str, device: str):
|
| 85 |
global _PIPELINE, _PIPELINE_META
|
| 86 |
|
| 87 |
model_id = (model_id or MODEL_ID_DEFAULT).strip()
|
| 88 |
-
device = "cuda" if device == "cuda" and torch.cuda.is_available() else "cpu"
|
| 89 |
-
|
| 90 |
-
if
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
or _PIPELINE_META.get("device") != device
|
| 94 |
-
):
|
| 95 |
-
_PIPELINE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
|
| 96 |
_PIPELINE_META = {"model_id": model_id, "device": device}
|
| 97 |
|
| 98 |
return _PIPELINE
|
| 99 |
|
| 100 |
|
| 101 |
# =========================
|
| 102 |
-
#
|
| 103 |
# =========================
|
| 104 |
-
def
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
# =========================
|
| 150 |
-
#
|
| 151 |
# =========================
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
upload_csv,
|
| 156 |
-
csv_column,
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
q_low
|
| 170 |
-
q_high
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
if q_low >= q_high:
|
| 172 |
raise gr.Error("Quantile low deve essere < quantile high.")
|
| 173 |
|
| 174 |
device = pick_device(device_ui)
|
| 175 |
-
pipe = get_pipeline(model_id, device)
|
| 176 |
|
| 177 |
-
#
|
|
|
|
|
|
|
| 178 |
if input_mode == "Test CSV":
|
| 179 |
if not test_csv_name:
|
| 180 |
-
raise gr.Error("Seleziona un
|
| 181 |
-
|
| 182 |
-
if not os.path.exists(
|
| 183 |
-
raise gr.Error(f"
|
| 184 |
-
y, used_col = load_series_from_csv(
|
| 185 |
-
|
| 186 |
|
| 187 |
elif input_mode == "Upload CSV":
|
| 188 |
if upload_csv is None:
|
| 189 |
-
raise gr.Error("Carica un CSV
|
| 190 |
-
y, used_col = load_series_from_csv(upload_csv.name, csv_column)
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
else:
|
| 194 |
-
y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
-
#
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
| 204 |
prediction_length=int(prediction_length),
|
| 205 |
-
|
| 206 |
-
id_column="id",
|
| 207 |
-
timestamp_column="timestamp",
|
| 208 |
-
target="target",
|
| 209 |
)
|
|
|
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
col_high = pick_quantile_column(pred_df, q_high)
|
| 215 |
-
|
| 216 |
-
# pred_df contains the forecast horizon rows; keep only series_0
|
| 217 |
-
pred_df = pred_df[pred_df["id"] == "series_0"].copy()
|
| 218 |
-
|
| 219 |
-
ts_fcst = pd.to_datetime(pred_df["timestamp"]).to_numpy()
|
| 220 |
-
low = pred_df[col_low].to_numpy(dtype=np.float32)
|
| 221 |
-
median = pred_df[col_med].to_numpy(dtype=np.float32)
|
| 222 |
-
high = pred_df[col_high].to_numpy(dtype=np.float32)
|
| 223 |
|
| 224 |
-
#
|
| 225 |
-
t_hist = np.arange(len(y))
|
| 226 |
t_fcst = np.arange(len(y), len(y) + int(prediction_length))
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
f"q{q_high:.2f}": high,
|
| 247 |
-
}
|
| 248 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
info = {
|
| 254 |
-
"
|
| 255 |
-
"device": device,
|
| 256 |
-
"source": source,
|
| 257 |
"history_points": int(len(y)),
|
| 258 |
"prediction_length": int(prediction_length),
|
| 259 |
-
"
|
| 260 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
}
|
| 262 |
|
| 263 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
|
| 266 |
# =========================
|
| 267 |
# UI
|
| 268 |
# =========================
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
gr.Markdown(
|
| 271 |
-
"
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
"Questa versione usa **predict_df()** (API consigliata per Chronos-2) e calcola direttamente i **quantili**. "
|
| 276 |
)
|
| 277 |
|
| 278 |
with gr.Row():
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
run_btn.click(
|
| 313 |
-
fn=
|
| 314 |
inputs=[
|
| 315 |
-
input_mode,
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
q_high,
|
| 328 |
-
device_ui,
|
| 329 |
-
model_id,
|
| 330 |
],
|
| 331 |
-
outputs=[plot, table, download, info],
|
| 332 |
)
|
| 333 |
|
| 334 |
demo.queue()
|
|
|
|
| 1 |
import os
|
| 2 |
+
import math
|
| 3 |
+
import time
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
import numpy as np
|
| 8 |
import pandas as pd
|
| 9 |
import gradio as gr
|
|
|
|
| 10 |
import torch
|
| 11 |
+
import plotly.graph_objects as go
|
| 12 |
|
| 13 |
from chronos import Chronos2Pipeline
|
| 14 |
|
|
|
|
| 18 |
# =========================
|
| 19 |
MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
|
| 20 |
DATA_DIR = "data"
|
| 21 |
+
OUT_DIR = "/tmp"
|
| 22 |
|
| 23 |
|
| 24 |
# =========================
|
| 25 |
+
# Utilities
|
| 26 |
# =========================
|
| 27 |
+
def available_test_csv() -> list[str]:
|
| 28 |
if not os.path.isdir(DATA_DIR):
|
| 29 |
return []
|
| 30 |
+
return sorted([f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv")])
|
| 31 |
|
| 32 |
|
| 33 |
def pick_device(ui_choice: str) -> str:
|
| 34 |
+
if ui_choice and ui_choice.startswith("cuda") and torch.cuda.is_available():
|
| 35 |
return "cuda"
|
| 36 |
return "cpu"
|
| 37 |
|
| 38 |
|
| 39 |
+
def safe_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
| 40 |
+
denom = np.maximum(1e-8, np.abs(y_true))
|
| 41 |
+
return float(np.mean(np.abs((y_true - y_pred) / denom)) * 100.0)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
| 45 |
+
return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
| 49 |
+
return float(np.mean(np.abs(y_true - y_pred)))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def coverage(y_true: np.ndarray, low: np.ndarray, high: np.ndarray) -> float:
|
| 53 |
+
inside = (y_true >= low) & (y_true <= high)
|
| 54 |
+
return float(np.mean(inside) * 100.0)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def interval_width(low: np.ndarray, high: np.ndarray) -> float:
|
| 58 |
+
return float(np.mean(high - low))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def format_kpi(label: str, value: str, hint: str = "") -> str:
|
| 62 |
+
# Simple “card” layout via HTML
|
| 63 |
+
hint_html = f"<div style='opacity:.75;font-size:12px;margin-top:6px;'>{hint}</div>" if hint else ""
|
| 64 |
+
return f"""
|
| 65 |
+
<div style="
|
| 66 |
+
border:1px solid rgba(255,255,255,.12);
|
| 67 |
+
border-radius:16px;
|
| 68 |
+
padding:14px 16px;
|
| 69 |
+
background: rgba(255,255,255,.04);
|
| 70 |
+
backdrop-filter: blur(6px);
|
| 71 |
+
">
|
| 72 |
+
<div style="font-size:12px;opacity:.8;">{label}</div>
|
| 73 |
+
<div style="font-size:22px;font-weight:700;margin-top:4px;">{value}</div>
|
| 74 |
+
{hint_html}
|
| 75 |
+
</div>
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def make_sample_series(
|
| 80 |
+
n: int,
|
| 81 |
+
seed: int,
|
| 82 |
+
trend: float,
|
| 83 |
+
season_period: int,
|
| 84 |
+
season_amp: float,
|
| 85 |
+
noise: float,
|
| 86 |
+
positive_shift: bool = True,
|
| 87 |
+
) -> np.ndarray:
|
| 88 |
rng = np.random.default_rng(int(seed))
|
| 89 |
+
t = np.arange(int(n), dtype=np.float32)
|
| 90 |
y = (
|
| 91 |
float(trend) * t
|
| 92 |
+ float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period)))
|
| 93 |
+
+ rng.normal(0.0, float(noise), size=int(n))
|
| 94 |
+
).astype(np.float32)
|
| 95 |
+
|
| 96 |
+
if positive_shift and float(np.min(y)) < 0:
|
| 97 |
+
y = y - float(np.min(y))
|
| 98 |
+
return y
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
+
def load_series_from_csv(csv_path: str, column: Optional[str]) -> Tuple[np.ndarray, str, pd.DataFrame]:
|
| 102 |
+
df = pd.read_csv(csv_path)
|
| 103 |
if df.shape[1] == 0:
|
| 104 |
raise ValueError("CSV vuoto o non leggibile.")
|
| 105 |
|
| 106 |
col = (column or "").strip()
|
| 107 |
+
if not col:
|
|
|
|
| 108 |
numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
|
|
|
|
| 109 |
if not numeric_cols:
|
| 110 |
+
# Try coercion: maybe numeric stored as strings
|
| 111 |
for c in df.columns:
|
| 112 |
coerced = pd.to_numeric(df[c], errors="coerce")
|
| 113 |
+
if coerced.notna().sum() > 0:
|
| 114 |
numeric_cols.append(c)
|
| 115 |
+
if not numeric_cols:
|
| 116 |
+
raise ValueError("Non trovo colonne numeriche nel CSV.")
|
| 117 |
col = numeric_cols[0]
|
| 118 |
|
| 119 |
if col not in df.columns:
|
| 120 |
+
raise ValueError(f"Colonna '{col}' non trovata. Disponibili: {list(df.columns)}")
|
| 121 |
|
| 122 |
+
y = pd.to_numeric(df[col], errors="coerce").dropna().astype(np.float32).to_numpy()
|
| 123 |
if len(y) < 10:
|
| 124 |
+
raise ValueError("Serie troppo corta (minimo consigliato: 10 punti).")
|
| 125 |
|
| 126 |
+
return y, col, df
|
| 127 |
|
| 128 |
|
| 129 |
# =========================
|
| 130 |
+
# Model cache
|
| 131 |
# =========================
|
| 132 |
_PIPELINE = None
|
| 133 |
+
_PIPELINE_META = {"model_id": None, "device": None}
|
| 134 |
|
| 135 |
|
| 136 |
+
def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
|
| 137 |
global _PIPELINE, _PIPELINE_META
|
| 138 |
|
| 139 |
model_id = (model_id or MODEL_ID_DEFAULT).strip()
|
| 140 |
+
device = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"
|
| 141 |
+
|
| 142 |
+
if _PIPELINE is None or _PIPELINE_META["model_id"] != model_id or _PIPELINE_META["device"] != device:
|
| 143 |
+
pipe = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
|
| 144 |
+
_PIPELINE = pipe
|
|
|
|
|
|
|
|
|
|
| 145 |
_PIPELINE_META = {"model_id": model_id, "device": device}
|
| 146 |
|
| 147 |
return _PIPELINE
|
| 148 |
|
| 149 |
|
| 150 |
# =========================
|
| 151 |
+
# Plotly helpers
|
| 152 |
# =========================
|
| 153 |
+
def plot_forecast_interactive(
|
| 154 |
+
y: np.ndarray,
|
| 155 |
+
median: np.ndarray,
|
| 156 |
+
low: np.ndarray,
|
| 157 |
+
high: np.ndarray,
|
| 158 |
+
title: str,
|
| 159 |
+
q_low: float,
|
| 160 |
+
q_high: float,
|
| 161 |
+
) -> go.Figure:
|
| 162 |
+
t_hist = np.arange(len(y))
|
| 163 |
+
t_fcst = np.arange(len(y), len(y) + len(median))
|
| 164 |
+
|
| 165 |
+
fig = go.Figure()
|
| 166 |
+
|
| 167 |
+
fig.add_trace(go.Scatter(
|
| 168 |
+
x=t_hist, y=y, mode="lines",
|
| 169 |
+
name="History",
|
| 170 |
+
hovertemplate="t=%{x}<br>y=%{y:.4f}<extra></extra>"
|
| 171 |
+
))
|
| 172 |
+
|
| 173 |
+
# Upper bound (invisible line), then lower bound with fill to create band
|
| 174 |
+
fig.add_trace(go.Scatter(
|
| 175 |
+
x=t_fcst, y=high, mode="lines",
|
| 176 |
+
name="Upper",
|
| 177 |
+
line=dict(width=0),
|
| 178 |
+
showlegend=False,
|
| 179 |
+
hoverinfo="skip"
|
| 180 |
+
))
|
| 181 |
+
fig.add_trace(go.Scatter(
|
| 182 |
+
x=t_fcst, y=low, mode="lines",
|
| 183 |
+
name=f"Band [{q_low:.2f}, {q_high:.2f}]",
|
| 184 |
+
fill="tonexty",
|
| 185 |
+
line=dict(width=0),
|
| 186 |
+
hovertemplate="t=%{x}<br>low=%{y:.4f}<extra></extra>"
|
| 187 |
+
))
|
| 188 |
+
|
| 189 |
+
fig.add_trace(go.Scatter(
|
| 190 |
+
x=t_fcst, y=median, mode="lines",
|
| 191 |
+
name="Forecast (median)",
|
| 192 |
+
hovertemplate="t=%{x}<br>median=%{y:.4f}<extra></extra>"
|
| 193 |
+
))
|
| 194 |
+
|
| 195 |
+
fig.add_vline(x=len(y) - 1, line_width=1, line_dash="dash", opacity=0.6)
|
| 196 |
+
|
| 197 |
+
fig.update_layout(
|
| 198 |
+
title=title,
|
| 199 |
+
hovermode="x unified",
|
| 200 |
+
margin=dict(l=10, r=10, t=55, b=10),
|
| 201 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
|
| 202 |
+
xaxis_title="t",
|
| 203 |
+
yaxis_title="value",
|
| 204 |
)
|
| 205 |
+
return fig
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def plot_backtest_interactive(
|
| 209 |
+
y_train: np.ndarray,
|
| 210 |
+
y_true: np.ndarray,
|
| 211 |
+
pred_median: np.ndarray,
|
| 212 |
+
low: np.ndarray,
|
| 213 |
+
high: np.ndarray,
|
| 214 |
+
q_low: float,
|
| 215 |
+
q_high: float,
|
| 216 |
+
) -> go.Figure:
|
| 217 |
+
t_train = np.arange(len(y_train))
|
| 218 |
+
t_test = np.arange(len(y_train), len(y_train) + len(y_true))
|
| 219 |
+
|
| 220 |
+
fig = go.Figure()
|
| 221 |
+
fig.add_trace(go.Scatter(x=t_train, y=y_train, mode="lines", name="Train"))
|
| 222 |
+
fig.add_trace(go.Scatter(x=t_test, y=y_true, mode="lines", name="True (holdout)"))
|
| 223 |
+
|
| 224 |
+
fig.add_trace(go.Scatter(x=t_test, y=high, mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip"))
|
| 225 |
+
fig.add_trace(go.Scatter(
|
| 226 |
+
x=t_test, y=low, mode="lines",
|
| 227 |
+
fill="tonexty", line=dict(width=0),
|
| 228 |
+
name=f"Band [{q_low:.2f}, {q_high:.2f}]"
|
| 229 |
+
))
|
| 230 |
+
fig.add_trace(go.Scatter(x=t_test, y=pred_median, mode="lines", name="Pred (median)"))
|
| 231 |
+
|
| 232 |
+
fig.add_vline(x=len(y_train) - 1, line_width=1, line_dash="dash", opacity=0.6)
|
| 233 |
+
|
| 234 |
+
fig.update_layout(
|
| 235 |
+
title="Backtest (holdout) — interactive",
|
| 236 |
+
hovermode="x unified",
|
| 237 |
+
margin=dict(l=10, r=10, t=55, b=10),
|
| 238 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
|
| 239 |
+
xaxis_title="t",
|
| 240 |
+
yaxis_title="value",
|
| 241 |
+
)
|
| 242 |
+
return fig
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def plot_sample_distribution(samples: np.ndarray) -> go.Figure:
|
| 246 |
+
# show distribution for a few horizons
|
| 247 |
+
if samples.ndim != 2:
|
| 248 |
+
samples = np.asarray(samples)
|
| 249 |
+
n_h = samples.shape[1]
|
| 250 |
+
idxs = []
|
| 251 |
+
for frac in [0.1, 0.5, 0.9]:
|
| 252 |
+
i = int(round((n_h - 1) * frac))
|
| 253 |
+
idxs.append(i)
|
| 254 |
+
idxs = sorted(set(idxs))
|
| 255 |
+
|
| 256 |
+
fig = go.Figure()
|
| 257 |
+
for i in idxs:
|
| 258 |
+
fig.add_trace(go.Histogram(
|
| 259 |
+
x=samples[:, i],
|
| 260 |
+
name=f"h={i+1}",
|
| 261 |
+
opacity=0.6
|
| 262 |
+
))
|
| 263 |
+
fig.update_layout(
|
| 264 |
+
barmode="overlay",
|
| 265 |
+
title="Forecast sample distributions (selected horizons)",
|
| 266 |
+
margin=dict(l=10, r=10, t=55, b=10),
|
| 267 |
+
xaxis_title="value",
|
| 268 |
+
yaxis_title="count",
|
| 269 |
+
)
|
| 270 |
+
return fig
|
| 271 |
|
| 272 |
|
| 273 |
# =========================
|
| 274 |
+
# Core run
|
| 275 |
# =========================
|
| 276 |
+
@dataclass
|
| 277 |
+
class RunResult:
|
| 278 |
+
forecast_fig: go.Figure
|
| 279 |
+
backtest_fig: Optional[go.Figure]
|
| 280 |
+
dist_fig: go.Figure
|
| 281 |
+
forecast_table: pd.DataFrame
|
| 282 |
+
backtest_table: Optional[pd.DataFrame]
|
| 283 |
+
forecast_csv_path: str
|
| 284 |
+
backtest_csv_path: Optional[str]
|
| 285 |
+
kpi_html: str
|
| 286 |
+
info: dict
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def run_dashboard(
|
| 290 |
+
input_mode: str,
|
| 291 |
+
test_csv_name: str,
|
| 292 |
upload_csv,
|
| 293 |
+
csv_column: str,
|
| 294 |
+
|
| 295 |
+
# sample params
|
| 296 |
+
n: int,
|
| 297 |
+
seed: int,
|
| 298 |
+
trend: float,
|
| 299 |
+
season_period: int,
|
| 300 |
+
season_amp: float,
|
| 301 |
+
noise: float,
|
| 302 |
+
|
| 303 |
+
# forecast params
|
| 304 |
+
prediction_length: int,
|
| 305 |
+
num_samples: int,
|
| 306 |
+
q_low: float,
|
| 307 |
+
q_high: float,
|
| 308 |
+
|
| 309 |
+
# backtest
|
| 310 |
+
do_backtest: bool,
|
| 311 |
+
holdout: int,
|
| 312 |
+
|
| 313 |
+
# system
|
| 314 |
+
device_ui: str,
|
| 315 |
+
model_id: str,
|
| 316 |
+
) -> RunResult:
|
| 317 |
+
|
| 318 |
if q_low >= q_high:
|
| 319 |
raise gr.Error("Quantile low deve essere < quantile high.")
|
| 320 |
|
| 321 |
device = pick_device(device_ui)
|
|
|
|
| 322 |
|
| 323 |
+
# ---------
|
| 324 |
+
# Load series
|
| 325 |
+
# ---------
|
| 326 |
if input_mode == "Test CSV":
|
| 327 |
if not test_csv_name:
|
| 328 |
+
raise gr.Error("Seleziona un Test CSV dalla dropdown.")
|
| 329 |
+
csv_path = os.path.join(DATA_DIR, test_csv_name)
|
| 330 |
+
if not os.path.exists(csv_path):
|
| 331 |
+
raise gr.Error(f"File non trovato: {csv_path}")
|
| 332 |
+
y, used_col, df_preview = load_series_from_csv(csv_path, csv_column)
|
| 333 |
+
source_title = f"Test CSV: {test_csv_name} • col={used_col}"
|
| 334 |
|
| 335 |
elif input_mode == "Upload CSV":
|
| 336 |
if upload_csv is None:
|
| 337 |
+
raise gr.Error("Carica un CSV (Upload CSV) oppure cambia modalità.")
|
| 338 |
+
y, used_col, df_preview = load_series_from_csv(upload_csv.name, csv_column)
|
| 339 |
+
source_title = f"Upload CSV • col={used_col}"
|
| 340 |
+
|
| 341 |
+
else:
|
| 342 |
+
y = make_sample_series(n, seed, trend, season_period, season_amp, noise, positive_shift=True)
|
| 343 |
+
df_preview = pd.DataFrame({"value": y})
|
| 344 |
+
used_col = "value"
|
| 345 |
+
source_title = "Sample series"
|
| 346 |
+
|
| 347 |
+
if do_backtest and holdout >= len(y):
|
| 348 |
+
raise gr.Error("Holdout deve essere più piccolo della lunghezza dello storico.")
|
| 349 |
+
|
| 350 |
+
# ---------
|
| 351 |
+
# Model
|
| 352 |
+
# ---------
|
| 353 |
+
t0 = time.time()
|
| 354 |
+
pipe = get_pipeline(model_id, device)
|
| 355 |
|
| 356 |
+
# ---------
|
| 357 |
+
# Forecast
|
| 358 |
+
# ---------
|
| 359 |
+
samples = pipe.predict(
|
| 360 |
+
context=y.tolist(),
|
| 361 |
prediction_length=int(prediction_length),
|
| 362 |
+
num_samples=int(num_samples),
|
|
|
|
|
|
|
|
|
|
| 363 |
)
|
| 364 |
+
samples = np.asarray(samples, dtype=np.float32)
|
| 365 |
|
| 366 |
+
median = np.quantile(samples, 0.50, axis=0)
|
| 367 |
+
low = np.quantile(samples, float(q_low), axis=0)
|
| 368 |
+
high = np.quantile(samples, float(q_high), axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
+
# Tables
|
|
|
|
| 371 |
t_fcst = np.arange(len(y), len(y) + int(prediction_length))
|
| 372 |
+
forecast_df = pd.DataFrame({
|
| 373 |
+
"t": t_fcst,
|
| 374 |
+
"median": median,
|
| 375 |
+
f"q{q_low:.2f}": low,
|
| 376 |
+
f"q{q_high:.2f}": high,
|
| 377 |
+
})
|
| 378 |
+
|
| 379 |
+
forecast_csv_path = os.path.join(OUT_DIR, "chronos2_forecast.csv")
|
| 380 |
+
forecast_df.to_csv(forecast_csv_path, index=False)
|
| 381 |
+
|
| 382 |
+
# Plots
|
| 383 |
+
forecast_fig = plot_forecast_interactive(
|
| 384 |
+
y=y,
|
| 385 |
+
median=median,
|
| 386 |
+
low=low,
|
| 387 |
+
high=high,
|
| 388 |
+
title=f"Forecast — {source_title}",
|
| 389 |
+
q_low=q_low,
|
| 390 |
+
q_high=q_high,
|
|
|
|
|
|
|
| 391 |
)
|
| 392 |
+
dist_fig = plot_sample_distribution(samples)
|
| 393 |
+
|
| 394 |
+
# ---------
|
| 395 |
+
# Backtest (optional)
|
| 396 |
+
# ---------
|
| 397 |
+
backtest_fig = None
|
| 398 |
+
backtest_df = None
|
| 399 |
+
backtest_csv_path = None
|
| 400 |
+
|
| 401 |
+
kpi_items = []
|
| 402 |
+
# Always show run/system KPIs
|
| 403 |
+
elapsed = time.time() - t0
|
| 404 |
+
kpi_items.append(format_kpi("Device", device.upper(), f"torch.cuda={torch.cuda.is_available()}"))
|
| 405 |
+
kpi_items.append(format_kpi("Model", (model_id or MODEL_ID_DEFAULT), "Chronos-2 pipeline"))
|
| 406 |
+
kpi_items.append(format_kpi("Latency", f"{elapsed:.2f}s", "model load cached after first run"))
|
| 407 |
+
kpi_items.append(format_kpi("Samples", f"{int(num_samples)}", "more = smoother quantiles"))
|
| 408 |
+
|
| 409 |
+
# Coverage/width (forecast only) – informational
|
| 410 |
+
kpi_items.append(format_kpi("Interval", f"[{q_low:.2f}, {q_high:.2f}]", "uncertainty band"))
|
| 411 |
+
kpi_items.append(format_kpi("Avg band width", f"{interval_width(low, high):.3f}", "forecast band only"))
|
| 412 |
+
|
| 413 |
+
if do_backtest:
|
| 414 |
+
y_train = y[:-int(holdout)]
|
| 415 |
+
y_true = y[-int(holdout):]
|
| 416 |
+
|
| 417 |
+
bt_samples = pipe.predict(
|
| 418 |
+
context=y_train.tolist(),
|
| 419 |
+
prediction_length=int(holdout),
|
| 420 |
+
num_samples=int(num_samples),
|
| 421 |
+
)
|
| 422 |
+
bt_samples = np.asarray(bt_samples, dtype=np.float32)
|
| 423 |
+
bt_median = np.quantile(bt_samples, 0.50, axis=0)
|
| 424 |
+
bt_low = np.quantile(bt_samples, float(q_low), axis=0)
|
| 425 |
+
bt_high = np.quantile(bt_samples, float(q_high), axis=0)
|
| 426 |
+
|
| 427 |
+
# Metrics
|
| 428 |
+
bt_mae = mae(y_true, bt_median)
|
| 429 |
+
bt_rmse = rmse(y_true, bt_median)
|
| 430 |
+
bt_mape = safe_mape(y_true, bt_median)
|
| 431 |
+
bt_cov = coverage(y_true, bt_low, bt_high)
|
| 432 |
+
bt_w = interval_width(bt_low, bt_high)
|
| 433 |
+
|
| 434 |
+
kpi_items.append(format_kpi("Backtest MAE", f"{bt_mae:.3f}", f"holdout={holdout}"))
|
| 435 |
+
kpi_items.append(format_kpi("Backtest RMSE", f"{bt_rmse:.3f}", ""))
|
| 436 |
+
kpi_items.append(format_kpi("Backtest MAPE", f"{bt_mape:.2f}%", ""))
|
| 437 |
+
kpi_items.append(format_kpi("Coverage", f"{bt_cov:.1f}%", "inside band"))
|
| 438 |
+
kpi_items.append(format_kpi("Backtest width", f"{bt_w:.3f}", "avg band width"))
|
| 439 |
+
|
| 440 |
+
backtest_fig = plot_backtest_interactive(
|
| 441 |
+
y_train=y_train,
|
| 442 |
+
y_true=y_true,
|
| 443 |
+
pred_median=bt_median,
|
| 444 |
+
low=bt_low,
|
| 445 |
+
high=bt_high,
|
| 446 |
+
q_low=q_low,
|
| 447 |
+
q_high=q_high,
|
| 448 |
+
)
|
| 449 |
|
| 450 |
+
t_test = np.arange(len(y_train), len(y_train) + int(holdout))
|
| 451 |
+
backtest_df = pd.DataFrame({
|
| 452 |
+
"t": t_test,
|
| 453 |
+
"true": y_true,
|
| 454 |
+
"pred_median": bt_median,
|
| 455 |
+
f"q{q_low:.2f}": bt_low,
|
| 456 |
+
f"q{q_high:.2f}": bt_high,
|
| 457 |
+
})
|
| 458 |
+
backtest_csv_path = os.path.join(OUT_DIR, "chronos2_backtest.csv")
|
| 459 |
+
backtest_df.to_csv(backtest_csv_path, index=False)
|
| 460 |
+
|
| 461 |
+
kpi_html = f"""
|
| 462 |
+
<div style="display:grid; grid-template-columns: repeat(6, minmax(0, 1fr)); gap:12px;">
|
| 463 |
+
{''.join(kpi_items)}
|
| 464 |
+
</div>
|
| 465 |
+
"""
|
| 466 |
|
| 467 |
info = {
|
| 468 |
+
"source": source_title,
|
|
|
|
|
|
|
| 469 |
"history_points": int(len(y)),
|
| 470 |
"prediction_length": int(prediction_length),
|
| 471 |
+
"num_samples": int(num_samples),
|
| 472 |
+
"q_low": float(q_low),
|
| 473 |
+
"q_high": float(q_high),
|
| 474 |
+
"backtest": bool(do_backtest),
|
| 475 |
+
"holdout": int(holdout) if do_backtest else None,
|
| 476 |
+
"column_used": used_col,
|
| 477 |
}
|
| 478 |
|
| 479 |
+
return RunResult(
|
| 480 |
+
forecast_fig=forecast_fig,
|
| 481 |
+
backtest_fig=backtest_fig,
|
| 482 |
+
dist_fig=dist_fig,
|
| 483 |
+
forecast_table=forecast_df,
|
| 484 |
+
backtest_table=backtest_df,
|
| 485 |
+
forecast_csv_path=forecast_csv_path,
|
| 486 |
+
backtest_csv_path=backtest_csv_path,
|
| 487 |
+
kpi_html=kpi_html,
|
| 488 |
+
info=info,
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def run_dashboard_wrapped(*args):
|
| 493 |
+
res = run_dashboard(*args)
|
| 494 |
+
# outputs must be basic objects
|
| 495 |
+
# If no backtest, send an empty placeholder plot and empty table/file
|
| 496 |
+
empty_fig = go.Figure().update_layout(
|
| 497 |
+
title="Backtest disabled",
|
| 498 |
+
margin=dict(l=10, r=10, t=55, b=10),
|
| 499 |
+
)
|
| 500 |
+
empty_df = pd.DataFrame()
|
| 501 |
+
|
| 502 |
+
return (
|
| 503 |
+
res.kpi_html,
|
| 504 |
+
res.forecast_fig,
|
| 505 |
+
(res.backtest_fig if res.backtest_fig is not None else empty_fig),
|
| 506 |
+
res.dist_fig,
|
| 507 |
+
res.forecast_table,
|
| 508 |
+
(res.backtest_table if res.backtest_table is not None else empty_df),
|
| 509 |
+
res.forecast_csv_path,
|
| 510 |
+
(res.backtest_csv_path if res.backtest_csv_path is not None else None),
|
| 511 |
+
res.info,
|
| 512 |
+
)
|
| 513 |
|
| 514 |
|
| 515 |
# =========================
|
| 516 |
# UI
|
| 517 |
# =========================
|
| 518 |
+
css = """
|
| 519 |
+
:root { --radius: 18px; }
|
| 520 |
+
.gradio-container { max-width: 1200px !important; }
|
| 521 |
+
"""
|
| 522 |
+
|
| 523 |
+
with gr.Blocks(title="Chronos-2 • Forecast Dashboard", css=css) as demo:
|
| 524 |
gr.Markdown(
|
| 525 |
+
"""
|
| 526 |
+
# ⏱️ Chronos-2 Forecast Dashboard
|
| 527 |
+
Una dashboard interattiva (Plotly) per testare **Amazon Chronos-2** su serie storiche (sample / CSV / upload), con **bande di incertezza** e **backtest**.
|
| 528 |
+
"""
|
|
|
|
| 529 |
)
|
| 530 |
|
| 531 |
with gr.Row():
|
| 532 |
+
with gr.Column(scale=1, min_width=360):
|
| 533 |
+
gr.Markdown("## Input")
|
| 534 |
+
|
| 535 |
+
input_mode = gr.Radio(
|
| 536 |
+
["Sample", "Test CSV", "Upload CSV"],
|
| 537 |
+
value="Sample",
|
| 538 |
+
label="Sorgente dati",
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
test_csv_name = gr.Dropdown(
|
| 542 |
+
choices=available_test_csv(),
|
| 543 |
+
value=None,
|
| 544 |
+
label="Test CSV (cartella data/)",
|
| 545 |
+
info="Comparirà qui se metti .csv dentro data/",
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
|
| 549 |
+
csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
|
| 550 |
+
|
| 551 |
+
gr.Markdown("## Sistema")
|
| 552 |
+
device_ui = gr.Dropdown(
|
| 553 |
+
["cpu", "cuda (se disponibile)"],
|
| 554 |
+
value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
|
| 555 |
+
label="Device",
|
| 556 |
+
)
|
| 557 |
+
model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")
|
| 558 |
+
|
| 559 |
+
with gr.Accordion("Sample generator", open=False):
|
| 560 |
+
n = gr.Slider(60, 1200, value=300, step=10, label="History length")
|
| 561 |
+
seed = gr.Number(value=42, precision=0, label="Seed")
|
| 562 |
+
trend = gr.Slider(0.0, 0.2, value=0.03, step=0.005, label="Trend")
|
| 563 |
+
season_period = gr.Slider(2, 120, value=14, step=1, label="Season period")
|
| 564 |
+
season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
|
| 565 |
+
noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")
|
| 566 |
+
|
| 567 |
+
gr.Markdown("## Forecast settings")
|
| 568 |
+
prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
|
| 569 |
+
num_samples = gr.Slider(50, 800, value=300, step=25, label="Num samples (quantili più stabili)")
|
| 570 |
+
q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
|
| 571 |
+
q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")
|
| 572 |
+
|
| 573 |
+
gr.Markdown("## Backtest")
|
| 574 |
+
do_backtest = gr.Checkbox(value=True, label="Esegui backtest holdout")
|
| 575 |
+
holdout = gr.Slider(5, 240, value=30, step=1, label="Holdout points")
|
| 576 |
+
|
| 577 |
+
run_btn = gr.Button("Run", variant="primary")
|
| 578 |
+
|
| 579 |
+
with gr.Column(scale=2):
|
| 580 |
+
gr.Markdown("## KPI")
|
| 581 |
+
kpi_html = gr.HTML()
|
| 582 |
+
|
| 583 |
+
with gr.Tabs():
|
| 584 |
+
with gr.Tab("Forecast"):
|
| 585 |
+
forecast_plot = gr.Plot(label="Interactive forecast (Plotly)")
|
| 586 |
+
forecast_table = gr.Dataframe(label="Forecast table", interactive=False)
|
| 587 |
+
|
| 588 |
+
with gr.Tab("Backtest"):
|
| 589 |
+
backtest_plot = gr.Plot(label="Interactive backtest (Plotly)")
|
| 590 |
+
backtest_table = gr.Dataframe(label="Backtest table", interactive=False)
|
| 591 |
+
|
| 592 |
+
with gr.Tab("Distributions"):
|
| 593 |
+
dist_plot = gr.Plot(label="Sample distributions (selected horizons)")
|
| 594 |
+
|
| 595 |
+
with gr.Tab("Export"):
|
| 596 |
+
gr.Markdown("Scarica i CSV prodotti dall’ultima run:")
|
| 597 |
+
forecast_download = gr.File(label="Forecast CSV")
|
| 598 |
+
backtest_download = gr.File(label="Backtest CSV")
|
| 599 |
+
|
| 600 |
+
with gr.Tab("Run info"):
|
| 601 |
+
run_info = gr.JSON(label="Info")
|
| 602 |
|
| 603 |
run_btn.click(
|
| 604 |
+
fn=run_dashboard_wrapped,
|
| 605 |
inputs=[
|
| 606 |
+
input_mode, test_csv_name, upload_csv, csv_column,
|
| 607 |
+
n, seed, trend, season_period, season_amp, noise,
|
| 608 |
+
prediction_length, num_samples, q_low, q_high,
|
| 609 |
+
do_backtest, holdout,
|
| 610 |
+
device_ui, model_id,
|
| 611 |
+
],
|
| 612 |
+
outputs=[
|
| 613 |
+
kpi_html,
|
| 614 |
+
forecast_plot, backtest_plot, dist_plot,
|
| 615 |
+
forecast_table, backtest_table,
|
| 616 |
+
forecast_download, backtest_download,
|
| 617 |
+
run_info,
|
|
|
|
|
|
|
|
|
|
| 618 |
],
|
|
|
|
| 619 |
)
|
| 620 |
|
| 621 |
demo.queue()
|