| |
| |
| |
| |
| |
| |
| |
| from argparse import Namespace |
| from typing import NamedTuple, Optional |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
|
|
class AdaptorInput(NamedTuple):
    """Bundle of inputs handed to an adaptor's ``forward`` method.

    ``AdaptorBase.forward`` takes one of these and returns a ``RadioOutput``.
    Being a ``NamedTuple``, the fields are immutable and can be accessed
    either by name or by position, in the order declared below.
    """

    # Image batch fed to the model (shape / normalization not established
    # here -- TODO confirm with callers).
    images: torch.Tensor
    # Summary tensor produced upstream (presumably a pooled, per-image
    # output -- verify against the producer).
    summary: torch.Tensor
    # Feature tensor produced upstream (exact layout not established here).
    features: torch.Tensor
    # Format tag for ``features``; presumably names its memory layout --
    # NOTE(review): confirm the accepted values.
    feature_fmt: str
    # Integer patch size; presumably the backbone's patch edge length in
    # pixels -- TODO confirm.
    patch_size: int
|
|
|
|
class RadioOutput(NamedTuple):
    """Pair of ``(summary, features)`` tensors returned by an adaptor.

    Either field may be ``None``; ``to`` preserves ``None`` entries as-is.
    """

    # Optional because ``to`` below explicitly accepts and propagates None.
    summary: Optional[torch.Tensor]
    features: Optional[torch.Tensor]

    def to(self, *args, **kwargs) -> "RadioOutput":
        """Return a new ``RadioOutput`` with each tensor passed through ``Tensor.to``.

        All positional and keyword arguments are forwarded unchanged to
        ``torch.Tensor.to`` (e.g. a device, a dtype, or both). Fields that
        are ``None`` stay ``None``. The original instance is not modified.
        """

        def _move(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            # Skip missing outputs instead of failing on None.to(...).
            return tensor.to(*args, **kwargs) if tensor is not None else None

        return RadioOutput(_move(self.summary), _move(self.features))
|
|
|
|
class AdaptorBase(nn.Module):
    """Interface for adaptor modules.

    An adaptor's ``forward`` maps an ``AdaptorInput`` to a ``RadioOutput``.
    This base class provides no implementation; concrete subclasses must
    override ``forward``.
    """

    def forward(self, input: AdaptorInput) -> RadioOutput:
        """Transform ``input`` into a ``RadioOutput``.

        Raises:
            NotImplementedError: always, on the base class.
        """
        message = "Subclasses must implement this!"
        raise NotImplementedError(message)
|
|