| |
| |
| |
| |
| |
| |
| |
| from argparse import Namespace |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput |
| from .adaptor_mlp import create_mlp_from_state |
|
|
|
|
class GenericAdaptor(AdaptorBase):
    """Adaptor that maps backbone outputs through two separately-built MLPs.

    ``head_mlp`` is applied to ``input.summary`` and ``feat_mlp`` to
    ``input.features``. Both MLPs are constructed by
    ``create_mlp_from_state`` from the same ``state`` object, distinguished
    only by the key prefix passed in (``'summary.'`` vs ``'feature.'``).
    NOTE(review): presumably the prefix selects the matching sub-dict of
    ``state`` — confirm against ``adaptor_mlp.create_mlp_from_state``.
    """

    def __init__(self, main_config: Namespace, adaptor_config, state):
        """Build both MLP heads.

        Args:
            main_config: Configuration namespace; only ``mlp_version`` is read.
            adaptor_config: Currently unused by this adaptor.
            state: Weights/state object handed unchanged to
                ``create_mlp_from_state`` for each head.
        """
        super().__init__()

        # Both heads share the MLP version; look it up once.
        version = main_config.mlp_version
        self.head_mlp = create_mlp_from_state(version, state, 'summary.')
        self.feat_mlp = create_mlp_from_state(version, state, 'feature.')

    def forward(self, input: AdaptorInput) -> RadioOutput:
        """Run each head on its corresponding field of ``input``.

        Returns:
            ``RadioOutput`` built positionally from the summary-head output
            followed by the feature-head output.
        """
        # Arguments are evaluated left to right, so the summary head still runs
        # before the feature head.
        return RadioOutput(
            self.head_mlp(input.summary),
            self.feat_mlp(input.features),
        )
|
|