"""Load the saved weights into every listed model and report the outcome.

When run as a script, each model class below is instantiated with the keyword
arguments produced by its paired ``*_kwargs`` function, the matching ``.pth``
file under ``<current working directory>/weights`` is loaded onto the CPU, and
the value returned by ``load_state_dict`` is printed.
"""

import os

import torch

from models.LSTM import LSTM
from models.LSTNet import LSTNet
from models.Transformer import Transformer
from models.Autoformer import Autoformer
from models.Informer import Informer
from models.PatchTST import PatchTST
from models.TimesNet import TimesNet
from models.TimesFM import TimesFM

# NOTE(review): the *_kwargs functions used below are presumably provided by
# this star import (nothing else in this file defines them) -- confirm.
from model_kwargs import *  # noqa: F401,F403

# Settings that select which weight files are loaded; all three values are
# embedded in the weight file name, and the first two are also passed to the
# kwargs functions.
lookback, lookahead, heterogeneity = 512, 48, 'HET'


if __name__ == "__main__":

    # Keep each model class next to the function that builds its constructor
    # keyword arguments, so the pairing cannot drift out of alignment the way
    # two separate parallel lists could.
    model_specs = [
        (LSTM, lstm_kwargs),
        (LSTNet, lstnet_kwargs),
        (Transformer, transformer_kwargs),
        (Autoformer, autoformer_kwargs),
        (Informer, informer_kwargs),
        (PatchTST, patchtst_kwargs),
        (TimesNet, timesnet_kwargs),
        (TimesFM, timesfm_kwargs),
    ]

    for model_class, kw_fn in model_specs:
        model = model_class(**kw_fn(lookback=lookback, lookahead=lookahead))

        # Weight files are resolved relative to the current working directory,
        # not relative to this file's location.
        weight_path = os.path.join(
            os.getcwd(),
            'weights',
            f'{model_class.__name__}_L_{lookback}_T_{lookahead}_{heterogeneity}.pth',
        )

        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be loaded here; consider weights_only=True if the
        # installed torch version supports it.
        state_dict = torch.load(weight_path, map_location='cpu')
        result = model.load_state_dict(state_dict)

        print(f"Loading weight for model {model_class.__name__}, lookback {lookback}, lookahead {lookahead}, heterogeneity {heterogeneity}, and the result was: {result}.")