| import argparse |
| import logging |
| import csv |
| import random |
| import warnings |
| import time |
| import json |
| from pathlib import Path |
| from functools import partial |
| from typing import Dict, List, Tuple, Any, Optional |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import albumentations as A |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from sklearn.model_selection import train_test_split |
| from sklearn.metrics import ( |
| accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix |
| ) |
| from rasterio.errors import NotGeoreferencedWarning |
| from sentence_transformers import SentenceTransformer |
|
|
| |
| import terramind |
| from terratorch.tasks import ClassificationTask |
| from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY |
| from terramind.models.terramind_register import build_terrammind_vit |
|
|
| |
| from methane_text_datamodule import MethaneTextDataModule |
|
|
| |
|
|
# Root logging configuration: timestamped INFO-level messages for the whole run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Silence noisy third-party output: rasterio environment chatter, the
# "not georeferenced" warning raised for plain image chips, and generic
# FutureWarnings from library deprecations.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Band layout handed to the pretrained TerraMind Sentinel-2 L2A backbone.
# The key follows the terramind pretrained-weight naming scheme.
PRETRAINED_BANDS = {
    'untok_sen2l2a@224': [
        "COASTAL_AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1", "RED_EDGE_2",
        "RED_EDGE_3", "NIR_BROAD", "NIR_NARROW", "WATER_VAPOR", "SWIR_1", "SWIR_2",
    ]
}
|
|
def set_seed(seed: int = 42):
    """Seed every RNG used in this script: python, numpy, and torch (CPU + GPU)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
def get_training_transforms() -> A.Compose:
    """Build the albumentations augmentation pipeline applied to training chips.

    Spatial augmentations only (no photometric ops): elastic warping,
    random 90-degree rotations, flips, and a small shift/scale/rotate jitter.
    """
    # NOTE(review): A.Flip is deprecated/removed in newer albumentations
    # releases (use HorizontalFlip/VerticalFlip) — confirm the pinned version.
    return A.Compose([
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
    ])
|
|
| |
|
|
| |
# Frozen sentence-embedding model used to encode captions at train time.
# Loaded once at import, best-effort: on failure EMBB_MODEL is None and the
# problem surfaces later when a forward pass tries to encode captions.
try:
    EMBB_MODEL = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')


    # Keep the text encoder on GPU alongside the vision backbone when possible.
    if torch.cuda.is_available():
        EMBB_MODEL = EMBB_MODEL.to("cuda")
except Exception as e:
    logger.warning(f"Could not load SentenceTransformer: {e}")
    EMBB_MODEL = None
|
|
class TerraMindWithText(nn.Module):
    """TerraMind ViT backbone augmented with a frozen caption embedding.

    Wraps ``build_terrammind_vit`` and, on every forward pass, appends the
    SentenceTransformer embedding of the batch captions to the list of
    vision features so a downstream decoder can fuse both modalities.
    """

    def __init__(self, terramind_kwargs: dict):
        super().__init__()
        self.terramind = build_terrammind_vit(
            variant='terramind_v1_base',
            encoder_depth=12,
            dim=768,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=False,
            proj_bias=False,
            mlp_bias=False,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.SiLU,
            gated_mlp=True,
            pretrained_bands=PRETRAINED_BANDS,
            **terramind_kwargs
        )
        # Advertise one 768-dim feature map per encoder layer (12 layers).
        self.out_channels = [768] * 12

    def forward(self, x, captions):
        """Return the backbone's vision features with the caption embedding appended.

        Args:
            x: model input, forwarded verbatim to the TerraMind backbone.
            captions: caption strings for the batch (passed to
                ``SentenceTransformer.encode``).

        Raises:
            RuntimeError: if the module-level SentenceTransformer failed to load.
        """
        vision_features = self.terramind(x)

        # BUGFIX/robustness: the embedding model is loaded best-effort at
        # import time and may be None; fail with a clear message instead of
        # an opaque AttributeError deep inside training.
        if EMBB_MODEL is None:
            raise RuntimeError(
                "SentenceTransformer embedding model is unavailable; "
                "captions cannot be encoded (see import-time warning)."
            )

        # Captions are conditioning context, not trained parameters.
        with torch.no_grad():
            captions_embed = EMBB_MODEL.encode(captions, convert_to_tensor=True, show_progress_bar=False)

        # NOTE(review): a bare squeeze() drops ALL size-1 dims, including a
        # batch dim of 1 — confirm encode()'s output shape; squeeze(1) may be
        # safer if the extra dimension is always axis 1.
        if len(captions_embed.shape) == 3:
            captions_embed = captions_embed.squeeze()

        return vision_features + [captions_embed]
|
|
@TERRATORCH_BACKBONE_REGISTRY.register
def terramind_v1_base_with_text(**kwargs):
    """Registry factory: expose TerraMindWithText as a terratorch backbone.

    All keyword arguments are forwarded to ``build_terrammind_vit`` via
    ``TerraMindWithText.__init__``.
    """
    return TerraMindWithText(terramind_kwargs=kwargs)
|
|
@TERRATORCH_DECODER_REGISTRY.register
class SimpleDecoder(nn.Module):
    """Fuse TerraMind patch features with a caption embedding and classify.

    Consumes the feature list produced by ``TerraMindWithText`` (12
    transformer token maps followed by the caption embedding), projects both
    modalities to 256 channels, fuses them convolutionally, and emits
    per-image class logits of shape (B, num_classes).
    """

    # This decoder produces final logits; terratorch must not append a head.
    includes_head = True

    def __init__(self, input_dim=768, num_classes=2, caption_dim=384):
        """
        Args:
            input_dim: backbone feature dim, or a list whose first entry is used
                (terratorch may pass the backbone's out_channels list).
            num_classes: number of output classes.
            caption_dim: dimensionality of the caption sentence embedding.
        """
        super().__init__()

        dim = input_dim[0] if isinstance(input_dim, (list, tuple)) else input_dim

        # Image branch: 768 -> 512 -> 256 channels.
        self.image_conv = nn.Sequential(
            nn.Conv2d(dim, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3)
        )

        # Text branch: caption embedding -> 256 dims.
        self.caption_mlp = nn.Sequential(
            nn.Linear(caption_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        # NOTE(review): defined but never used in forward(); kept so any
        # existing checkpoints still load. Wire it in or remove it.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=256, num_heads=8, dropout=0.1, batch_first=True
        )

        # Fusion of concatenated image+caption maps: 512 -> 128 channels.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3)
        )

        # Classification head. BUGFIX: the final conv previously emitted a
        # single channel regardless of num_classes, which made the trainer's
        # softmax/argmax over dim=1 degenerate (always class 0). Emit
        # num_classes channels so the output carries real class logits.
        self.conv_head = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

        self.out_channels = num_classes

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        """Return per-image class logits of shape (B, num_classes).

        Args:
            features: backbone output — 12 token maps of shape (B, N, C)
                followed by the caption embedding of shape (B, caption_dim).
        """
        caption_embed = features[-1]
        # The 12 transformer layer outputs (matches backbone.out_channels).
        image_features = features[:12]

        # Average the layer outputs into a single (B, N, C) token map.
        x = torch.stack(image_features, dim=1).mean(dim=1)

        B, N, C = x.shape
        # Assumes a square patch grid (N must be a perfect square).
        H = W = int(N ** 0.5)

        x = x.permute(0, 2, 1).view(B, C, H, W)
        img_features = self.image_conv(x)

        # Restore the batch dim when a single caption was squeezed to 1-D.
        if caption_embed.dim() == 1:
            caption_embed = caption_embed.unsqueeze(0)

        caption_features = self.caption_mlp(caption_embed)

        # Broadcast the caption vector over the spatial grid.
        caption_spatial = caption_features.unsqueeze(-1).unsqueeze(-1)
        caption_spatial = caption_spatial.expand(B, -1, H, W)

        fused_features = torch.cat([img_features, caption_spatial], dim=1)
        fused = self.fusion_conv(fused_features)

        # (B, num_classes, H, W) -> global average pool -> (B, num_classes),
        # the shape ClassificationTask and the CE loss expect.
        logits = self.conv_head(fused)
        return logits.mean(dim=(2, 3))
|
|
| |
|
|
class MetricTracker:
    """Accumulates loss and hard (argmax) labels across an epoch, then
    reports binary classification metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state for a fresh epoch."""
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch: running loss plus argmax-decoded labels/predictions."""
        self.total_loss += loss
        self.steps += 1
        for store, tensor in ((self.all_targets, targets), (self.all_predictions, probabilities)):
            store.extend(tensor.argmax(dim=1).detach().cpu().numpy())

    def compute(self) -> Dict[str, float]:
        """Return the epoch's metric dict, or {} if nothing was recorded."""
        if not self.all_targets:
            return {}

        y_true, y_pred = self.all_targets, self.all_predictions
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        specificity = tn / (tn + fp) if (tn + fp) != 0 else 0.0

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(y_true, y_pred),
            "Specificity": specificity,
            "Sensitivity": recall_score(y_true, y_pred, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(y_true, y_pred, average='binary', pos_label=1, zero_division=0),
            "MCC": matthews_corrcoef(y_true, y_pred),
        }
|
|
class MethaneTextTrainer:
    """End-to-end trainer for the TerraMind + caption classification task.

    Builds the terratorch ClassificationTask model, runs a manual
    train/validate loop, logs per-epoch metrics to CSV, and checkpoints both
    the best (lowest validation loss) and final model states.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # One checkpoint sub-directory per held-out fold.
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.model = self._init_model()
        self.optimizer, self.scheduler = self._init_optimizer()
        # Reuse the loss the terratorch task configured ("ce").
        self.criterion = self.task.criterion
        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Build the terratorch ClassificationTask and return its model on device."""
        model_args = dict(
            backbone="terramind_v1_base_with_text",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            num_classes=2,
            head_dropout=0.3,
            decoder="SimpleDecoder",
        )

        self.task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """AdamW plus a plateau scheduler driven by validation loss."""
        optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
        return optimizer, scheduler

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Run one full pass over `dataloader` and return the epoch metrics.

        Gradients are enabled and the optimizer stepped only when
        stage == "train"; any other stage runs in eval mode.
        """
        is_train = stage == "train"
        self.model.train() if is_train else self.model.eval()
        tracker = MetricTracker()

        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
            for batch in pbar:
                inputs = batch['S2L2A'].to(self.device)
                captions = batch['caption']
                targets = batch['label'].to(self.device)

                outputs = self.model(x={"S2L2A": inputs}, captions=captions)
                logits = outputs.output
                # BUGFIX: CrossEntropyLoss expects raw logits; the previous
                # code fed it softmax probabilities (a double softmax), which
                # flattens gradients. Softmax is now applied only for metrics.
                loss = self.criterion(logits, targets)
                probabilities = torch.softmax(logits, dim=1)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                tracker.update(loss.item(), targets, probabilities)
                pbar.set_postfix(loss=f"{loss.item():.4f}")

        return tracker.compute()

    def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
        """Append one epoch row to train_val_metrics.csv, writing a header first.

        BUGFIX: the existence check must happen BEFORE opening in append
        mode — open(..., 'a') creates the file, so the old check always saw
        an existing file and never wrote the header row.
        """
        csv_path = self.save_dir / 'train_val_metrics.csv'
        headers = ['Epoch'] + [f'Train_{k}' for k in train_metrics.keys()] + [f'Val_{k}' for k in val_metrics.keys()]
        write_header = not csv_path.exists()

        with open(csv_path, mode='a', newline='') as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(headers)
            writer.writerow([epoch] + list(train_metrics.values()) + list(val_metrics.values()))

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Full training loop with per-epoch validation and checkpointing."""
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            # Plateau scheduler keys off validation loss.
            self.scheduler.step(val_metrics['Loss'])
            self.log_to_csv(epoch, train_metrics, val_metrics)

            logger.info(f"Train Loss: {train_metrics['Loss']:.4f} | Val Loss: {val_metrics['Loss']:.4f} | Val F1: {val_metrics['F1']:.4f}")

            # Checkpoint whenever validation loss improves.
            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                torch.save(self.model.state_dict(), self.save_dir / "best_model.pth")
                logger.info(f"--> New best model saved")

        torch.save(self.model.state_dict(), self.save_dir / "final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")
|
|
| |
|
|
def read_captions(json_path: Path, captions_dict: Dict) -> Dict:
    """Load captions from a JSON file into `captions_dict` and return it.

    The JSON maps file paths to nested text lists; the caption is whatever
    follows the "CAPTION:" marker, keyed by the path's parent directory
    name. Missing files and parse errors are logged and leave the dict
    unchanged.
    """
    if not json_path.exists():
        logger.warning(f"Caption file not found: {json_path}")
        return captions_dict

    try:
        with open(json_path, "r", encoding="utf-8") as file:
            payload = json.load(file)

        for raw_path, text_list in payload.items():
            # Skip malformed entries (empty / non-list / empty first element).
            if not (text_list and isinstance(text_list, list) and text_list[0]):
                continue

            text_content = text_list[0][0]
            _, marker, tail = text_content.partition("CAPTION:")
            if not marker:
                continue

            # Key the caption by the parent directory of the file path.
            parts = raw_path.replace("\\", "/").split("/")
            if len(parts) >= 2:
                captions_dict[parts[-2]] = tail.strip()
    except Exception as e:
        logger.error(f"Error reading captions {json_path}: {e}")

    return captions_dict
|
|
def get_paths_for_fold(excel_file: str, folds: List[int]) -> List[str]:
    """Return the 'Filename' entries of the summary sheet whose 'Fold' is in `folds`."""
    summary = pd.read_excel(excel_file)
    in_folds = summary['Fold'].isin(folds)
    return summary.loc[in_folds, 'Filename'].tolist()
|
|
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    """Build train/val dataloaders from all folds except the held-out test fold.

    Captions from both JSON files are merged into a single dict keyed by
    scene directory; the pooled training folds are then split 80/20 into
    train and validation path lists with a fixed seed.
    """
    # Merge methane and no-methane captions into one lookup table.
    captions_dict = {}
    captions_dict = read_captions(Path(args.methane_captions), captions_dict)
    captions_dict = read_captions(Path(args.no_methane_captions), captions_dict)
    logger.info(f"Loaded {len(captions_dict)} captions.")


    # Every fold except the test fold forms the training pool.
    all_folds = range(1, args.num_folds + 1)
    train_pool_folds = [f for f in all_folds if f != args.test_fold]
    paths = get_paths_for_fold(args.excel_file, train_pool_folds)

    # Fixed random_state keeps train/val membership reproducible across runs.
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)
    logger.info(f"Train: {len(train_paths)}, Val: {len(val_paths)}")


    # NOTE(review): a single datamodule instance is reused by mutating
    # `paths` between setup() calls — this assumes MethaneTextDataModule
    # reads self.paths lazily in setup(); confirm against its implementation.
    datamodule = MethaneTextDataModule(
        data_root=args.root_dir,
        paths=paths,
        captions=captions_dict,
        train_transform=get_training_transforms(),
        batch_size=args.batch_size,
    )

    datamodule.paths = train_paths
    datamodule.setup(stage="train")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader
|
|
| |
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments for the training script.

    Args:
        argv: optional explicit argument list; when None (the default and
            the original behavior) arguments are read from ``sys.argv``.
            Accepting a list makes the parser testable and embeddable.

    Returns:
        The populated argparse Namespace.
    """
    parser = argparse.ArgumentParser(description="Methane Text-Multimodal Training")

    # Data / I/O paths.
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory for images')
    parser.add_argument('--excel_file', type=str, required=True, help='Path to Summary Excel')
    parser.add_argument('--methane_captions', type=str, required=True, help='Path to Methane JSON captions')
    parser.add_argument('--no_methane_captions', type=str, required=True, help='Path to No-Methane JSON captions')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='Output directory')

    # Optimization / cross-validation hyperparameters.
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--num_folds', type=int, default=5)
    parser.add_argument('--test_fold', type=int, default=2)
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args(argv)
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI args, seed all RNGs, build the fold's
    # train/val loaders, then run the full training loop.
    args = parse_args()
    set_seed(args.seed)

    train_loader, val_loader = get_data_loaders(args)

    trainer = MethaneTextTrainer(args)
    trainer.fit(train_loader, val_loader)