#!/usr/bin/env python3
"""
Chronos-2 Inference Pipeline with Past-Only Covariate Masking
Standalone inference script for HuggingFace Space deployment.
Uses predict_df() API with ALL 2,514 features, leveraging Chronos-2's mask-based attention.
FORCE REBUILD: v1.6.0 - Context window of 1,125 hours (~47 days), sized to fit A100-80GB VRAM
"""

import gc
import os
import time
from typing import List, Dict, Optional
from datetime import datetime, timedelta

# CRITICAL: Set PyTorch memory allocator config BEFORE importing torch
# This prevents memory fragmentation issues that cause OOM even with sufficient free memory
# See: https://pytorch.org/docs/stable/notes/cuda.html#environment-variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import polars as pl
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset
from chronos import Chronos2Pipeline

from .dynamic_forecast import DynamicForecast
from .feature_availability import FeatureAvailability


class ChronosInferencePipeline:
    """
    Production inference pipeline for Chronos-2 zero-shot forecasting WITH PAST-ONLY MASKING.
    Uses predict_df() API with ALL 2,514 features (known-future + past-only covariates).
    Past-only covariates (CNEC, volatility, historical flows) are masked in the future
    window, so the model learns cross-feature correlations from historical context via
    its attention mechanism.
    Designed for deployment as an API endpoint on HuggingFace Spaces.
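
    Example (a minimal sketch; the date is illustrative):
        >>> pipeline = ChronosInferencePipeline()
        >>> results = pipeline.run_forecast("2024-06-01", forecast_days=7)
        >>> sorted(results["borders"])          # per-border quantile forecasts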
    """

    def __init__(
        self,
        model_name: str = "amazon/chronos-2",
        device: str = "cuda",
        dtype: str = "bfloat16"
    ):
        """
        Initialize inference pipeline.

        Args:
            model_name: HuggingFace model identifier (chronos-2 supports covariates)
            device: Device for inference ('cuda' or 'cpu')
            dtype: Data type for model weights (bfloat16 for memory efficiency)
        """
        self.model_name = model_name
        self.device = device
        self.dtype = dtype

        # Model loaded on first inference (lazy loading)
        self._pipeline = None
        self._dataset = None
        self._borders = None

    def _load_model(self):
        """Load Chronos-2 model (cached after first call)"""
        if self._pipeline is None:
            print(f"Loading {self.model_name}...")
            start_time = time.time()

            dtype_map = {
                "bfloat16": torch.bfloat16,
                "float16": torch.float16,
                "float32": torch.float32
            }

            self._pipeline = Chronos2Pipeline.from_pretrained(
                self.model_name,
                device_map="auto",  # Auto-distribute across all available GPUs
                torch_dtype=dtype_map.get(self.dtype, torch.float32)
            )

            # Set model to evaluation mode (disables dropout, etc.)
            self._pipeline.model.eval()

            print(f"Model loaded in {time.time() - start_time:.1f}s")
            print(f"  Device: {next(self._pipeline.model.parameters()).device}")

            # GPU detection and memory profiling diagnostics
            if torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                total_vram = sum(torch.cuda.get_device_properties(i).total_memory for i in range(gpu_count))
                print(f"  [GPU] Detected {gpu_count} GPU(s)")
                print(f"  [GPU] Total VRAM: {total_vram/1e9:.1f} GB")
                print(f"  [MEMORY] After model load:")
                print(f"    GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
                print(f"    GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

        return self._pipeline

    def _load_dataset(self):
        """Load dataset from HuggingFace (cached after first call)"""
        if self._dataset is None:
            print("Loading dataset from HuggingFace...")
            start_time = time.time()

            hf_token = os.getenv("HF_TOKEN")
            dataset = load_dataset(
                "evgueni-p/fbmc-features-24month",
                split="train",
                token=hf_token
            )

            # Convert to Polars
            self._dataset = pl.from_arrow(dataset.data.table)

            # Extract available borders
            target_cols = [col for col in self._dataset.columns if col.startswith('target_border_')]
            self._borders = [col.replace('target_border_', '') for col in target_cols]

            print(f"Dataset loaded in {time.time() - start_time:.1f}s")
            print(f"  Shape: {self._dataset.shape}")
            print(f"  Borders: {len(self._borders)}")

            # Memory profiling diagnostics
            if torch.cuda.is_available():
                print(f"  [MEMORY] After dataset load:")
                print(f"    GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
                print(f"    GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

        return self._dataset, self._borders

    def run_forecast(
        self,
        run_date: str,
        borders: Optional[List[str]] = None,
        forecast_days: int = 7,
        context_hours: int = 1125,  # 1,125 hours = 46.9 days (1.5 months, fits A100-80GB)
        num_samples: int = 20
    ) -> Dict:
        """
        Run zero-shot forecast for specified borders.

        Args:
            run_date: Forecast run date (YYYY-MM-DD format)
            borders: List of borders to forecast (None = all borders)
            forecast_days: Forecast horizon in days (7 or 14)
            context_hours: Historical context window in hours
            num_samples: Number of probabilistic samples (recorded in run metadata
                only; predict_df() returns quantiles directly)

        Returns:
            Dictionary with 'run_date', 'forecast_days', per-border quantile
            forecasts under 'borders', and run 'metadata'
        """
        # Load model and dataset (cached)
        pipeline = self._load_model()
        df, all_borders = self._load_dataset()

        # Parse run date
        run_datetime = datetime.strptime(run_date, "%Y-%m-%d")
        run_datetime = run_datetime.replace(hour=23, minute=0)

        # Determine borders to forecast
        forecast_borders = borders if borders else all_borders
        prediction_hours = forecast_days * 24

        print(f"\nForecast configuration:")
        print(f"  Run date: {run_datetime}")
        print(f"  Borders: {len(forecast_borders)}")
        print(f"  Forecast horizon: {forecast_days} days ({prediction_hours} hours)")
        print(f"  Context window: {context_hours} hours")

        # Initialize dynamic forecast system
        forecaster = DynamicForecast(
            dataset=df,
            context_hours=context_hours,
            forecast_hours=prediction_hours
        )

        # Run forecasts for each border
        results = {
            'run_date': run_date,
            'forecast_days': forecast_days,
            'borders': {},
            'metadata': {
                'model': self.model_name,
                'device': self.device,
                'num_samples': num_samples,
                'context_hours': context_hours
            }
        }

        total_start = time.time()

        # PER-BORDER INFERENCE WITH PAST-ONLY COVARIATE MASKING
        # Using predict_df() API with ALL 2,514 features (known-future + past-only masked)
        print(f"\n[PAST-ONLY MASKING] Running inference for {len(forecast_borders)} borders with 2,514 features...")
        print(f"  Known-future: weather, generation, load forecasts (615 features)")
        print(f"  Past-only masked: CNEC outages, volatility, historical flows (1,899 features)")

        for i, border in enumerate(forecast_borders, 1):
            # Clear GPU cache BEFORE each border to prevent memory accumulation
            # This releases tensors from previous border (no-op on first iteration)
            # Does NOT affect model weights (120M params stay loaded)
            # Does NOT affect forecast accuracy (each border is independent)
            if i > 1:  # Skip on first border (clean GPU state)
                torch.cuda.empty_cache()
                gc.collect()  # Force Python garbage collection to free tensors

            border_start = time.time()
            print(f"\n  [{i}/{len(forecast_borders)}] {border}...", flush=True)

            try:
                # Extract data WITH covariates
                context_data, future_data = forecaster.prepare_forecast_data(
                    run_date=run_datetime,
                    border=border
                )
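
                # Assumption (DynamicForecast is not shown in this file): past-only
                # covariate columns are carried in future_data as NaN so that
                # predict_df() masks them and uses only their historical values.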

                print(f"    Context shape: {context_data.shape}, Future shape: {future_data.shape}", flush=True)
                print(f"    Using {len(future_data.columns)-2} features (known-future + past-only masked)", flush=True)

                # Run covariate-informed inference using DataFrame API
                # Note: predict_df() returns quantiles directly
                # Request 9 quantiles to capture learned uncertainty and tail events
                # Use torch.inference_mode() to disable gradient tracking (saves ~2-5 GB VRAM)
                with torch.inference_mode():
                    forecasts_df = pipeline.predict_df(
                        context_data,  # Historical data with ALL features
                        future_df=future_data,  # All 2,514 features (past-only masked)
                        prediction_length=prediction_hours,
                        id_column='border',
                        timestamp_column='timestamp',
                        target='target',
                        batch_size=32,  # Reduced from 64: halves the peak attention tensor (41.57 GB -> 20.79 GB) so it fits on a single GPU
                        quantile_levels=[0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99]  # 9 quantiles for volatility
                    )

                # Extract all 9 quantiles from predict_df() output
                # predict_df() returns quantiles directly as string columns
                if isinstance(forecasts_df, pd.DataFrame):
                    # Expected columns: '0.01', '0.05', '0.1', '0.25', '0.5', '0.75', '0.9', '0.95', '0.99'
                    quantile_cols = ['0.01', '0.05', '0.1', '0.25', '0.5', '0.75', '0.9', '0.95', '0.99']

                    # Extract all quantiles
                    quantiles = {}
                    for q in quantile_cols:
                        if q in forecasts_df.columns:
                            quantiles[q] = forecasts_df[q].values
                        else:
                            # Fallback if quantile missing
                            if '0.5' in forecasts_df.columns:
                                quantiles[q] = forecasts_df['0.5'].values  # Use median as fallback
                            elif 'predictions' in forecasts_df.columns:
                                quantiles[q] = forecasts_df['predictions'].values
                            else:
                                raise ValueError(f"Missing quantile {q} and no fallback available. Columns: {forecasts_df.columns.tolist()}")

                    # Round all quantiles to the nearest integer
                    # (capacity values are always whole MW)
                    for q_key in quantiles:
                        quantiles[q_key] = np.round(quantiles[q_key]).astype(int)

                    # Backward compatibility: still expose median, q10, q90 names
                    median = quantiles['0.5']
                    q10 = quantiles['0.1']
                    q90 = quantiles['0.9']
                else:
                    raise TypeError(f"Expected DataFrame from predict_df(), got {type(forecasts_df)}")

                inference_time = time.time() - border_start

                # Store results (backward compatible + all quantiles)
                results['borders'][border] = {
                    'median': median.tolist(),
                    'q10': q10.tolist(),
                    'q90': q90.tolist(),
                    # Add all 9 quantiles for adaptive selection
                    'q01': quantiles['0.01'].tolist(),
                    'q05': quantiles['0.05'].tolist(),
                    'q25': quantiles['0.25'].tolist(),
                    'q75': quantiles['0.75'].tolist(),
                    'q95': quantiles['0.95'].tolist(),
                    'q99': quantiles['0.99'].tolist(),
                    'inference_time_s': inference_time,
                    'used_covariates': True,
                    'num_features': len(future_data.columns) - 2  # Exclude border and timestamp
                }

                print(f"    [OK] Complete in {inference_time:.1f}s ({len(future_data.columns)-2} features with past-only masking)", flush=True)

            except Exception as e:
                import traceback
                error_msg = f"{type(e).__name__}: {str(e)}"
                traceback_str = traceback.format_exc()
                print(f"    [ERROR] {error_msg}", flush=True)
                print(f"Traceback:\n{traceback_str}", flush=True)
                results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}

        # Add summary metadata
        results['metadata']['total_time_s'] = time.time() - total_start
        results['metadata']['successful_borders'] = sum(
            1 for b in results['borders'].values() if 'error' not in b
        )

        print(f"\n{'='*60}")
        print(f"FORECAST COMPLETE")
        print(f"{'='*60}")
        print(f"Total time: {results['metadata']['total_time_s']:.1f}s")
        print(f"Successful: {results['metadata']['successful_borders']}/{len(forecast_borders)} borders")

        return results
    def export_to_parquet(self, results: Dict, output_path: str):
        """
        Export forecast results to parquet format.

        Args:
            results: Forecast results from run_forecast()
            output_path: Path to save parquet file
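
        Output columns follow the naming in the loop below: '{border}_median',
        '{border}_q10', '{border}_q90', plus '{border}_adaptive' when present.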
        """
        # Create forecast timestamps
        run_datetime = datetime.strptime(results['run_date'], "%Y-%m-%d")
        forecast_start = run_datetime + timedelta(days=1)  # Next day at midnight, not +1 hour
        forecast_hours = results['forecast_days'] * 24

        timestamps = [
            forecast_start + timedelta(hours=h)
            for h in range(forecast_hours)
        ]

        # Build DataFrame
        data = {'timestamp': timestamps}
        
        successful_borders = []
        failed_borders = []

        for border, forecast_data in results['borders'].items():
            if 'error' not in forecast_data:
                data[f'{border}_median'] = forecast_data['median']
                data[f'{border}_q10'] = forecast_data['q10']
                data[f'{border}_q90'] = forecast_data['q90']
                # Add adaptive forecast if available (learned uncertainty-based selection)
                if 'adaptive' in forecast_data:
                    data[f'{border}_adaptive'] = forecast_data['adaptive']
                successful_borders.append(border)
            else:
                failed_borders.append((border, forecast_data['error']))

        # Log results
        print(f"[EXPORT] Forecast export summary:", flush=True)
        print(f"  Successful: {len(successful_borders)} borders", flush=True)
        print(f"  Failed: {len(failed_borders)} borders", flush=True)
        if failed_borders:
            print(f"[EXPORT] Errors:", flush=True)
            for border, error in failed_borders:
                print(f"  {border}: {error}", flush=True)
        
        df = pl.DataFrame(data)
        df.write_parquet(output_path)

        print(f"[EXPORT] Exported to: {output_path}", flush=True)
        print(f"[EXPORT] Shape: {df.shape}, Columns: {len(df.columns)}", flush=True)

        return output_path


# Convenience function for API usage
def run_inference(
    run_date: str,
    forecast_type: str = "smoke_test",
    borders: Optional[List[str]] = None,
    output_dir: str = "/tmp"
) -> str:
    """
    Run forecast and return path to results file.

    Args:
        run_date: Forecast run date (YYYY-MM-DD)
        forecast_type: 'smoke_test' (7 days, 1 border) or 'full_14day' (14 days, all borders)
        borders: Specific borders to forecast (None = use forecast_type defaults)
        output_dir: Directory to save results

    Returns:
        Path to the forecast results parquet file, or to the debug text file
        if no border forecast succeeded
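
    Example (the date is illustrative):
        >>> path = run_inference("2024-06-01", forecast_type="smoke_test")
        >>> path  # parquet results, or the debug .txt if every border failed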
    """
    # Initialize pipeline
    pipeline = ChronosInferencePipeline()

    # Configure based on forecast type
    if forecast_type == "smoke_test":
        forecast_days = 7
        if borders is None:
            # Load just to get first border
            _, all_borders = pipeline._load_dataset()
            borders = [all_borders[0]]
    else:  # full_14day
        forecast_days = 14
        # borders = None means all borders

    # Run forecast
    results = pipeline.run_forecast(
        run_date=run_date,
        borders=borders,
        forecast_days=forecast_days
    )

    # Write debug file
    debug_filename = f"debug_{run_date}_{forecast_type}.txt"
    debug_path = os.path.join(output_dir, debug_filename)
    with open(debug_path, 'w') as f:
        f.write(f"Results summary:\n")
        f.write(f"  Run date: {results['run_date']}\n")
        f.write(f"  Forecast days: {results['forecast_days']}\n")
        f.write(f"  Borders in results: {list(results['borders'].keys())}\n\n")
        for border, data in results['borders'].items():
            if 'error' in data:
                f.write(f"  {border}: ERROR - {data['error']}\n")
                if 'traceback' in data:
                    f.write(f"\nFull Traceback:\n{data['traceback']}\n")
            else:
                f.write(f"  {border}: OK\n")
                f.write(f"    median count: {len(data.get('median', []))}\n")
                f.write(f"    q10 count: {len(data.get('q10', []))}\n")
                f.write(f"    q90 count: {len(data.get('q90', []))}\n")
    print(f"Debug file written to: {debug_path}", flush=True)
    
    # Export to parquet
    output_filename = f"forecast_{run_date}_{forecast_type}.parquet"
    output_path = os.path.join(output_dir, output_filename)
    pipeline.export_to_parquet(results, output_path)
    
    # Check if forecast has data, if not return debug file
    successful_count = sum(1 for data in results['borders'].values() if 'error' not in data)
    if successful_count == 0:
        print(f"[WARNING] No successful forecasts! Returning debug file instead.", flush=True)
        return debug_path
    
    return output_path
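

# Minimal CLI sketch. Assumptions: the module is run with `python -m <package>.inference`
# (required by the relative imports above), and the default date below is illustrative.
if __name__ == "__main__":
    import sys

    run_date_arg = sys.argv[1] if len(sys.argv) > 1 else "2024-06-01"
    forecast_type_arg = sys.argv[2] if len(sys.argv) > 2 else "smoke_test"
    result_path = run_inference(run_date_arg, forecast_type=forecast_type_arg)
    print(f"Results file: {result_path}")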