| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from argparse import Namespace |
| | from typing import NamedTuple, Optional |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| |
|
| |
|
class AdaptorInput(NamedTuple):
    """Bundle of tensors and metadata handed to an adaptor's ``forward``.

    Field order matters: instances may be built positionally, so the order
    below (images, summary, features, feature_fmt, patch_size) is part of the
    public interface.
    """

    # Input images. Shape/layout not established here — TODO confirm against callers.
    images: torch.Tensor
    # Summary tensor produced upstream (presumably a global/pooled token — verify).
    summary: torch.Tensor
    # Feature tensor produced upstream (presumably per-patch features — verify).
    features: torch.Tensor
    # String tag; presumably names the layout of `features` — TODO confirm accepted values.
    feature_fmt: str
    # Integer patch size; presumably the model's patch edge length in pixels — verify.
    patch_size: int
| |
|
| |
|
class RadioOutput(NamedTuple):
    """Model output consisting of a summary tensor and a features tensor.

    Either field may be ``None``; ``to`` below explicitly tolerates that, so the
    fields are annotated as ``Optional``.
    """

    summary: Optional[torch.Tensor]
    features: Optional[torch.Tensor]

    def to(self, *args, **kwargs) -> "RadioOutput":
        """Return a new ``RadioOutput`` with each non-``None`` field passed through ``Tensor.to``.

        Args:
            *args, **kwargs: Forwarded unchanged to ``torch.Tensor.to`` (device,
                dtype, etc.).

        Returns:
            A new ``RadioOutput``. ``None`` fields remain ``None``, and ``self``
            is not modified.
        """
        return RadioOutput(
            self.summary.to(*args, **kwargs) if self.summary is not None else None,
            self.features.to(*args, **kwargs) if self.features is not None else None,
        )
| |
|
| |
|
class AdaptorModuleBase(nn.Module):
    """Base ``nn.Module`` for adaptor modules that records two capability flags.

    Args:
        requires_summary_and_spatial: Stored unchanged on the instance. Going by
            the name, it marks that the module needs both summary and spatial
            inputs — TODO confirm against the code that reads it.
        handles_summary_and_spatial: Stored unchanged on the instance. It may
            only be ``True`` when ``requires_summary_and_spatial`` is ``True``.

    Raises:
        ValueError: If ``handles_summary_and_spatial`` is set without
            ``requires_summary_and_spatial``.
    """

    def __init__(
        self,
        requires_summary_and_spatial: bool,
        handles_summary_and_spatial: bool = False,
    ) -> None:
        super().__init__()
        # Check explicitly instead of with `assert`, which `python -O` strips and
        # would then let an inconsistent flag combination through silently.
        if handles_summary_and_spatial and not requires_summary_and_spatial:
            raise ValueError("If handles summary and spatial, must require it too!")
        self.requires_summary_and_spatial = requires_summary_and_spatial
        self.handles_summary_and_spatial = handles_summary_and_spatial
| |
|
| |
|
class AdaptorBase(nn.Module):
    """Interface for adaptors that map an ``AdaptorInput`` to a ``RadioOutput``.

    This base class provides no behavior of its own. Concrete adaptors must
    override ``forward``.
    """

    def forward(self, input: AdaptorInput) -> RadioOutput:
        """Produce a ``RadioOutput`` from ``input``; always raises in the base class.

        Raises:
            NotImplementedError: Always, until a subclass overrides this method.
        """
        raise NotImplementedError("Subclasses must implement this!")
| |
|