bosh94 committed on
Commit
f79bf21
·
verified ·
1 Parent(s): 0411375

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -66
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import numpy as np
3
  import pandas as pd
4
  import gradio as gr
@@ -7,34 +8,35 @@ import torch
7
 
8
  from chronos import Chronos2Pipeline
9
 
10
-
11
  # =========================
12
  # Config
13
  # =========================
14
- MODEL_ID_DEFAULT = "amazon/chronos-2"
15
  DATA_DIR = "data"
16
 
17
-
18
  # =========================
19
- # Utils
20
  # =========================
21
  def available_test_csv():
22
  if not os.path.isdir(DATA_DIR):
23
  return []
24
- return sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".csv"))
25
-
26
 
27
  def pick_device(ui_choice: str) -> str:
28
- if ui_choice.startswith("cuda") and torch.cuda.is_available():
29
  return "cuda"
30
  return "cpu"
31
 
32
-
 
 
33
  _PIPELINE = None
34
  _PIPELINE_META = {}
35
 
36
-
37
  def get_pipeline(model_id: str, device: str):
 
 
 
38
  global _PIPELINE, _PIPELINE_META
39
 
40
  model_id = (model_id or MODEL_ID_DEFAULT).strip()
@@ -45,44 +47,128 @@ def get_pipeline(model_id: str, device: str):
45
  or _PIPELINE_META.get("model_id") != model_id
46
  or _PIPELINE_META.get("device") != device
47
  ):
48
- pipe = Chronos2Pipeline.from_pretrained(
49
- model_id,
50
- device_map=device,
51
- )
52
- _PIPELINE = pipe
53
  _PIPELINE_META = {"model_id": model_id, "device": device}
54
 
55
  return _PIPELINE
56
 
57
-
 
 
58
  def make_sample_series(n, seed, trend, season_period, season_amp, noise):
59
  rng = np.random.default_rng(int(seed))
60
  t = np.arange(int(n))
61
  y = (
62
- trend * t
63
- + season_amp * np.sin(2 * np.pi * t / max(1, int(season_period)))
64
- + rng.normal(0, noise, size=len(t))
65
  )
66
- if y.min() < 0:
67
- y = y - y.min()
 
 
68
  return y.astype(np.float32)
69
 
70
-
71
  def load_series_from_csv(path_or_file, column=None):
72
  df = pd.read_csv(path_or_file)
73
 
74
- if column is None or column.strip() == "":
 
 
 
 
75
  numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
76
  if not numeric_cols:
77
- raise ValueError("Nessuna colonna numerica nel CSV.")
78
- column = numeric_cols[0]
79
-
80
- y = pd.to_numeric(df[column], errors="coerce").dropna().to_numpy()
 
 
 
 
 
 
 
 
 
 
81
  if len(y) < 10:
82
- raise ValueError("Serie troppo corta (minimo ~10 punti).")
83
 
84
- return y.astype(np.float32), column
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  # =========================
88
  # Forecast core
@@ -99,59 +185,61 @@ def run_forecast(
99
  season_amp,
100
  noise,
101
  prediction_length,
102
- num_samples,
103
  q_low,
104
  q_high,
105
  device_ui,
106
  model_id,
107
  ):
108
- if q_low >= q_high:
 
109
  raise gr.Error("Quantile low deve essere < quantile high.")
110
 
 
111
  device = pick_device(device_ui)
112
  pipe = get_pipeline(model_id, device)
113
 
114
- # -------------------------
115
- # Input data
116
- # -------------------------
117
- if input_mode == "Test CSV" and test_csv_name:
118
- path = os.path.join(DATA_DIR, test_csv_name)
119
- y, used_col = load_series_from_csv(path, csv_column)
 
 
120
  source = f"Test CSV: {test_csv_name} ({used_col})"
121
 
122
- elif input_mode == "Upload CSV" and upload_csv is not None:
 
 
123
  y, used_col = load_series_from_csv(upload_csv.name, csv_column)
124
  source = f"Upload CSV ({used_col})"
125
 
126
- else:
127
  y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
128
  source = "Sample data"
129
 
130
- # -------------------------
131
- # Forecast
132
- # -------------------------
133
- samples = pipe.predict(
134
- inputs=y.tolist(),
135
  prediction_length=int(prediction_length),
136
- num_samples=int(num_samples),
137
  )
138
- samples = np.asarray(samples, dtype=np.float32)
139
-
140
 
 
141
  median = np.quantile(samples, 0.50, axis=0)
142
- low = np.quantile(samples, q_low, axis=0)
143
- high = np.quantile(samples, q_high, axis=0)
144
 
145
- # -------------------------
146
  # Plot
147
- # -------------------------
148
  t_hist = np.arange(len(y))
149
  t_fcst = np.arange(len(y), len(y) + int(prediction_length))
150
 
151
  fig, ax = plt.subplots(figsize=(10, 4))
152
  ax.plot(t_hist, y, label="history")
153
  ax.plot(t_fcst, median, label="forecast (median)")
154
- ax.fill_between(t_fcst, low, high, alpha=0.25, label="confidence band")
155
  ax.axvline(len(y) - 1, linestyle="--", linewidth=1)
156
  ax.set_title(source)
157
  ax.set_xlabel("t")
@@ -159,15 +247,13 @@ def run_forecast(
159
  ax.grid(True, alpha=0.3)
160
  ax.legend()
161
 
162
- # -------------------------
163
- # Output
164
- # -------------------------
165
  out_df = pd.DataFrame(
166
  {
167
  "t": t_fcst,
168
  "median": median,
169
- f"q{q_low:.2f}": low,
170
- f"q{q_high:.2f}": high,
171
  }
172
  )
173
 
@@ -175,22 +261,23 @@ def run_forecast(
175
  out_df.to_csv(out_path, index=False)
176
 
177
  info = {
178
- "model_id": model_id,
179
  "device": device,
180
  "source": source,
181
- "history_points": len(y),
182
- "prediction_length": prediction_length,
183
- "num_samples": num_samples,
 
184
  }
185
 
186
  return fig, out_df, out_path, info
187
 
188
-
189
  # =========================
190
  # UI
191
  # =========================
192
  with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
193
- gr.Markdown("# ⏱️ Chronos-2 Forecast Demo")
 
194
 
195
  with gr.Row():
196
  input_mode = gr.Radio(
@@ -208,10 +295,10 @@ with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
208
  with gr.Row():
209
  test_csv_name = gr.Dropdown(
210
  choices=available_test_csv(),
211
- label="Test CSV disponibili",
212
  )
213
  upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
214
- csv_column = gr.Textbox(label="Colonna numerica (opzionale)")
215
 
216
  with gr.Accordion("Sample data settings", open=False):
217
  n = gr.Slider(60, 600, 220, step=10, label="History length")
@@ -223,7 +310,8 @@ with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
223
 
224
  with gr.Accordion("Forecast settings", open=True):
225
  prediction_length = gr.Slider(1, 180, 30, step=1, label="Prediction length")
226
- num_samples = gr.Slider(20, 400, 200, step=10, label="Num samples")
 
227
  q_low = gr.Slider(0.01, 0.49, 0.10, step=0.01, label="Quantile low")
228
  q_high = gr.Slider(0.51, 0.99, 0.90, step=0.01, label="Quantile high")
229
 
@@ -248,7 +336,7 @@ with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
248
  season_amp,
249
  noise,
250
  prediction_length,
251
- num_samples,
252
  q_low,
253
  q_high,
254
  device_ui,
 
1
  import os
2
+ import inspect
3
  import numpy as np
4
  import pandas as pd
5
  import gradio as gr
 
8
 
9
  from chronos import Chronos2Pipeline
10
 
 
11
  # =========================
12
  # Config
13
  # =========================
14
+ MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
15
  DATA_DIR = "data"
16
 
 
17
  # =========================
18
+ # Helpers: files & device
19
  # =========================
20
  def available_test_csv():
21
  if not os.path.isdir(DATA_DIR):
22
  return []
23
+ return sorted(f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv"))
 
24
 
25
  def pick_device(ui_choice: str) -> str:
26
+ if (ui_choice or "").startswith("cuda") and torch.cuda.is_available():
27
  return "cuda"
28
  return "cpu"
29
 
30
+ # =========================
31
+ # Model cache
32
+ # =========================
33
  _PIPELINE = None
34
  _PIPELINE_META = {}
35
 
 
36
  def get_pipeline(model_id: str, device: str):
37
+ """
38
+ Caches the pipeline across calls to avoid re-downloading and re-loading.
39
+ """
40
  global _PIPELINE, _PIPELINE_META
41
 
42
  model_id = (model_id or MODEL_ID_DEFAULT).strip()
 
47
  or _PIPELINE_META.get("model_id") != model_id
48
  or _PIPELINE_META.get("device") != device
49
  ):
50
+ # Chronos-2 pipeline
51
+ _PIPELINE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
 
 
 
52
  _PIPELINE_META = {"model_id": model_id, "device": device}
53
 
54
  return _PIPELINE
55
 
56
+ # =========================
57
+ # Data generation/loading
58
+ # =========================
59
  def make_sample_series(n, seed, trend, season_period, season_amp, noise):
60
  rng = np.random.default_rng(int(seed))
61
  t = np.arange(int(n))
62
  y = (
63
+ float(trend) * t
64
+ + float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period)))
65
+ + rng.normal(0.0, float(noise), size=len(t))
66
  )
67
+ # shift up if negative (not required, but keeps nice plots)
68
+ mn = float(np.min(y))
69
+ if mn < 0:
70
+ y = y - mn
71
  return y.astype(np.float32)
72
 
 
73
  def load_series_from_csv(path_or_file, column=None):
74
  df = pd.read_csv(path_or_file)
75
 
76
+ if df.shape[1] == 0:
77
+ raise ValueError("CSV vuoto o non leggibile.")
78
+
79
+ col = (column or "").strip()
80
+ if col == "":
81
  numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
82
  if not numeric_cols:
83
+ # try coercion to numeric on all columns (sometimes dtype is object)
84
+ numeric_cols = []
85
+ for c in df.columns:
86
+ coerced = pd.to_numeric(df[c], errors="coerce")
87
+ if coerced.notna().sum() >= 10:
88
+ numeric_cols.append(c)
89
+ if not numeric_cols:
90
+ raise ValueError("Nessuna colonna numerica nel CSV. Specifica una colonna con numeri.")
91
+ col = numeric_cols[0]
92
+
93
+ if col not in df.columns:
94
+ raise ValueError(f"Colonna '{col}' non trovata. Colonne: {list(df.columns)}")
95
+
96
+ y = pd.to_numeric(df[col], errors="coerce").dropna().to_numpy()
97
  if len(y) < 10:
98
+ raise ValueError("Serie troppo corta (minimo ~10 punti dopo dropna).")
99
 
100
+ return y.astype(np.float32), col
101
 
102
+ # =========================
103
+ # Chronos2 predict normalization
104
+ # =========================
105
def _to_numpy(obj):
    """Convert an array, a tensor-like object or a sequence to ``np.ndarray``."""
    if isinstance(obj, np.ndarray):
        return obj
    if hasattr(obj, "detach") and hasattr(obj, "cpu"):
        # Tensor-like object: np.asarray() fails on tensors that track
        # gradients or live on an accelerator, so detach and copy to host
        # memory explicitly first.
        return np.asarray(obj.detach().cpu().numpy())
    if isinstance(obj, (list, tuple)) and any(hasattr(item, "detach") for item in obj):
        # A sequence of tensors (e.g. one per input series): convert each.
        return np.asarray([_to_numpy(item) for item in obj])
    return np.asarray(obj)


def _extract_samples(pred_out):
    """Normalise the return value of ``pipe.predict`` to a NumPy array.

    Accepted shapes of ``pred_out``:
      - a NumPy array (returned unchanged),
      - a list/tuple of numbers, arrays or tensor-like objects,
      - a tensor-like object exposing ``detach()`` and ``cpu()``,
      - a dict holding the data under ``"samples"``, ``"predictions"``,
        ``"prediction"`` or ``"outputs"`` (checked in that order),
      - any object with a ``samples`` attribute.

    Returns:
        np.ndarray: The extracted values; dimensionality is whatever the
        pipeline produced and is not altered here.

    Raises:
        TypeError: ``pred_out`` is a dict without any of the known keys.
    """
    if isinstance(pred_out, (np.ndarray, list)):
        return _to_numpy(pred_out)

    if isinstance(pred_out, dict):
        for key in ("samples", "predictions", "prediction", "outputs"):
            if key in pred_out:
                return _to_numpy(pred_out[key])
        # np.asarray(dict) would yield a 0-d object array that only fails
        # later, far from the cause; report the real problem here instead.
        raise TypeError(
            f"Unsupported prediction output: dict with keys {sorted(map(str, pred_out))}"
        )

    # Result object exposing the draws as an attribute.
    if hasattr(pred_out, "samples"):
        return _to_numpy(pred_out.samples)

    # Last resort: let NumPy (or the tensor conversion) handle it.
    return _to_numpy(pred_out)
130
+
131
def chronos2_predict_samples(pipe, y, prediction_length: int, n_draws: int):
    """Call ``pipe.predict`` and return its output as a 2-D ``float32`` array.

    The call always passes ``inputs=`` and ``prediction_length=``.
    ``num_predictions=n_draws`` is added only when ``pipe.predict`` declares a
    parameter of that name; if the call is still rejected with ``TypeError``
    it is retried once without it.

    Args:
        pipe: Object exposing a ``predict`` method.
        y: 1-D array with the history; sent as ``y.tolist()``.
        prediction_length: Forecast horizon; converted with ``int``.
        n_draws: Requested number of draws; converted with ``int``.

    Returns:
        np.ndarray: ``float32`` array that is always 2-D. A 1-D result becomes
        a single row; for higher-rank results singleton axes are dropped and
        any remaining leading axes are flattened into the first axis.

    Raises:
        TypeError: ``pipe.predict`` rejected the call even without
            ``num_predictions``.
    """
    try:
        params = inspect.signature(pipe.predict).parameters
    except (TypeError, ValueError):
        # Some callables expose no introspectable signature; in that case
        # only the two mandatory arguments are sent.
        params = {}

    kwargs = {"inputs": y.tolist(), "prediction_length": int(prediction_length)}

    # API differences: some versions accept num_predictions, others do not.
    if "num_predictions" in params:
        kwargs["num_predictions"] = int(n_draws)

    try:
        out = pipe.predict(**kwargs)
    except TypeError:
        if "num_predictions" not in kwargs:
            # Nothing optional left to drop: re-raise with the original
            # traceback intact.
            raise
        del kwargs["num_predictions"]
        out = pipe.predict(**kwargs)

    samples = _extract_samples(out).astype(np.float32)

    # Normalise to exactly two dimensions.
    if samples.ndim > 2:
        samples = np.squeeze(samples)
    if samples.ndim > 2:
        # Squeezing was not enough: fold every leading axis into the first
        # one so that downstream axis-0 reductions still yield a 1-D result.
        rows = int(np.prod(samples.shape[:-1]))
        samples = samples.reshape(rows, samples.shape[-1])
    elif samples.ndim < 2:
        # Covers a single draw (1-D) and a fully squeezed scalar (0-D).
        samples = samples.reshape(1, -1)

    return samples
172
 
173
  # =========================
174
  # Forecast core
 
185
  season_amp,
186
  noise,
187
  prediction_length,
188
+ num_draws,
189
  q_low,
190
  q_high,
191
  device_ui,
192
  model_id,
193
  ):
194
+ # Validate quantiles
195
+ if float(q_low) >= float(q_high):
196
  raise gr.Error("Quantile low deve essere < quantile high.")
197
 
198
+ # Device + pipeline
199
  device = pick_device(device_ui)
200
  pipe = get_pipeline(model_id, device)
201
 
202
+ # Choose input series
203
+ if input_mode == "Test CSV":
204
+ if not test_csv_name:
205
+ raise gr.Error("Seleziona un file nella dropdown dei Test CSV oppure usa Sample/Upload.")
206
+ csv_path = os.path.join(DATA_DIR, test_csv_name)
207
+ if not os.path.exists(csv_path):
208
+ raise gr.Error(f"Non trovo {csv_path}. Assicurati che esista nel repo dello Space.")
209
+ y, used_col = load_series_from_csv(csv_path, csv_column)
210
  source = f"Test CSV: {test_csv_name} ({used_col})"
211
 
212
+ elif input_mode == "Upload CSV":
213
+ if upload_csv is None:
214
+ raise gr.Error("Carica un CSV oppure scegli Sample/Test CSV.")
215
  y, used_col = load_series_from_csv(upload_csv.name, csv_column)
216
  source = f"Upload CSV ({used_col})"
217
 
218
+ else: # Sample
219
  y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
220
  source = "Sample data"
221
 
222
+ # Forecast samples
223
+ samples = chronos2_predict_samples(
224
+ pipe=pipe,
225
+ y=y,
 
226
  prediction_length=int(prediction_length),
227
+ n_draws=int(num_draws),
228
  )
 
 
229
 
230
+ # Quantiles
231
  median = np.quantile(samples, 0.50, axis=0)
232
+ low = np.quantile(samples, float(q_low), axis=0)
233
+ high = np.quantile(samples, float(q_high), axis=0)
234
 
 
235
  # Plot
 
236
  t_hist = np.arange(len(y))
237
  t_fcst = np.arange(len(y), len(y) + int(prediction_length))
238
 
239
  fig, ax = plt.subplots(figsize=(10, 4))
240
  ax.plot(t_hist, y, label="history")
241
  ax.plot(t_fcst, median, label="forecast (median)")
242
+ ax.fill_between(t_fcst, low, high, alpha=0.25, label=f"band [{float(q_low):.2f}, {float(q_high):.2f}]")
243
  ax.axvline(len(y) - 1, linestyle="--", linewidth=1)
244
  ax.set_title(source)
245
  ax.set_xlabel("t")
 
247
  ax.grid(True, alpha=0.3)
248
  ax.legend()
249
 
250
+ # Output table + CSV
 
 
251
  out_df = pd.DataFrame(
252
  {
253
  "t": t_fcst,
254
  "median": median,
255
+ f"q{float(q_low):.2f}": low,
256
+ f"q{float(q_high):.2f}": high,
257
  }
258
  )
259
 
 
261
  out_df.to_csv(out_path, index=False)
262
 
263
  info = {
264
+ "model_id": (model_id or MODEL_ID_DEFAULT),
265
  "device": device,
266
  "source": source,
267
+ "history_points": int(len(y)),
268
+ "prediction_length": int(prediction_length),
269
+ "requested_draws": int(num_draws),
270
+ "returned_draws": int(samples.shape[0]),
271
  }
272
 
273
  return fig, out_df, out_path, info
274
 
 
275
  # =========================
276
  # UI
277
  # =========================
278
  with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
279
+ gr.Markdown("# ⏱️ Chronos-2 Forecast Demo (HF Spaces)\n\n"
280
+ "Supporta **Sample**, **Test CSV** (da cartella `data/`) e **Upload CSV**.")
281
 
282
  with gr.Row():
283
  input_mode = gr.Radio(
 
295
  with gr.Row():
296
  test_csv_name = gr.Dropdown(
297
  choices=available_test_csv(),
298
+ label="Test CSV disponibili (cartella data/)",
299
  )
300
  upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
301
+ csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
302
 
303
  with gr.Accordion("Sample data settings", open=False):
304
  n = gr.Slider(60, 600, 220, step=10, label="History length")
 
310
 
311
  with gr.Accordion("Forecast settings", open=True):
312
  prediction_length = gr.Slider(1, 180, 30, step=1, label="Prediction length")
313
+ # UI label stays "Num samples", internally treated as number of prediction draws if supported
314
+ num_draws = gr.Slider(1, 400, 200, step=10, label="Num samples (draws)")
315
  q_low = gr.Slider(0.01, 0.49, 0.10, step=0.01, label="Quantile low")
316
  q_high = gr.Slider(0.51, 0.99, 0.90, step=0.01, label="Quantile high")
317
 
 
336
  season_amp,
337
  noise,
338
  prediction_length,
339
+ num_draws,
340
  q_low,
341
  q_high,
342
  device_ui,