| import os |
| import rasterio |
| import torch |
| from torchgeo.datasets import NonGeoDataset |
| from torch.utils.data import DataLoader |
| from torchgeo.datamodules import NonGeoDataModule |
| from methane_classification_dataset import MethaneClassificationDataset |
|
|
class MethaneClassificationDataModule(NonGeoDataModule):
    """DataModule that builds ``MethaneClassificationDataset`` splits and loaders.

    Every split is constructed from the same ``data_root``, ``excel_file`` and
    ``paths``; the splits differ only in the transform handed to the dataset.
    Callers that need disjoint train/val/test data must therefore create one
    module per split (or pass already-filtered ``paths``).

    Args:
        data_root: Forwarded to the dataset as ``root_dir``.
        excel_file: Forwarded to the dataset as ``excel_file``.
        paths: Forwarded to the dataset as ``paths``.
        batch_size: Batch size passed to the base class and used by every
            DataLoader created here.
        num_workers: Worker count passed to the base class and used by every
            DataLoader created here.
        train_transform: Passed as ``transform`` to the training dataset.
        val_transform: Passed as ``transform`` to the validation dataset.
        test_transform: Passed as ``transform`` to the test/predict dataset.
        **kwargs: Forwarded unchanged to ``NonGeoDataModule.__init__``.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def _make_methane_dataset(self, transform):
        """Build one dataset over the configured files with ``transform``."""
        return MethaneClassificationDataset(
            root_dir=self.data_root,
            excel_file=self.excel_file,
            paths=self.paths,
            transform=transform,
        )

    def _make_methane_loader(self, dataset, *, shuffle, drop_last):
        """Wrap ``dataset`` in a DataLoader using the module's batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
        )

    def setup(self, stage: str = None):
        """Create the datasets required for ``stage``.

        Args:
            stage: ``"fit"``/``"train"`` builds the training dataset,
                ``"fit"``/``"validate"``/``"val"`` builds the validation
                dataset, ``"test"``/``"predict"`` builds the test dataset.
                ``None`` (the default) builds all of them, so a bare
                ``setup()`` call leaves the module fully usable instead of
                silently creating nothing.
        """
        build_all = stage is None

        if build_all or stage in ("fit", "train"):
            self.train_dataset = self._make_methane_dataset(self.train_transform)

        if build_all or stage in ("fit", "validate", "val"):
            self.val_dataset = self._make_methane_dataset(self.val_transform)

        if build_all or stage in ("test", "predict"):
            self.test_dataset = self._make_methane_dataset(self.test_transform)
            if build_all or stage == "predict":
                # This class defines no predict_dataloader of its own, so
                # prediction goes through the inherited one.
                # NOTE(review): torchgeo's base data module is expected to
                # look up ``predict_dataset`` there - confirm against the
                # installed torchgeo version.
                self.predict_dataset = self.test_dataset

    def train_dataloader(self):
        """Return the shuffled training loader; the last partial batch is dropped."""
        return self._make_methane_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation loader in fixed order.

        The final partial batch is kept so that every validation sample is
        evaluated, including when the dataset is smaller than one batch.
        """
        return self._make_methane_loader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Return the test loader in fixed order.

        The final partial batch is kept so that every test sample is
        evaluated, including when the dataset is smaller than one batch.
        """
        return self._make_methane_loader(self.test_dataset, shuffle=False, drop_last=False)
|
|