diff --git a/configs/release/check_middle_sensitivity.py b/configs/release/check_middle_sensitivity.py index 624df6d..4bb8066 100644 --- a/configs/release/check_middle_sensitivity.py +++ b/configs/release/check_middle_sensitivity.py @@ -86,6 +86,7 @@ def main(args: argparse.Namespace): for batch in tqdm(dl, desc=f"Processing {country} batches", leave=False): images = batch["image"] + masks = batch["mask"] if model_type in ["fcsiamdiff", "fcsiamconc", "fcsiamavg"]: images = rearrange(images, "b (t c) h w -> b t c h w", t=2) @@ -141,6 +142,16 @@ def main(args: argparse.Namespace): ) consensus_score = consensus.mean() + # Extract central region from ground truth mask + gt_central = masks[i, padding : padding + overlap_size, padding : padding + overlap_size].numpy() + + # Compute accuracy for each corner's prediction + acc1 = (hard_output[0] == gt_central).mean() + acc2 = (hard_output[1] == gt_central).mean() + acc3 = (hard_output[2] == gt_central).mean() + acc4 = (hard_output[3] == gt_central).mean() + mean_accuracy = (acc1 + acc2 + acc3 + acc4) / 4 + # Store result for this patch all_results.append( { @@ -148,6 +159,11 @@ def main(args: argparse.Namespace): "country": country, "patch_idx": patch_idx, "consensus_score": consensus_score, + "accuracy_corner1": acc1, + "accuracy_corner2": acc2, + "accuracy_corner3": acc3, + "accuracy_corner4": acc4, + "mean_accuracy": mean_accuracy, "split": args.split, } ) @@ -167,10 +183,13 @@ def main(args: argparse.Namespace): print(f"Min consensus: {df['consensus_score'].min():.4f}") print(f"Max consensus: {df['consensus_score'].max():.4f}") print(f"Median consensus: {df['consensus_score'].median():.4f}") + print( + f"Mean accuracy: {df['mean_accuracy'].mean():.4f} +/- {df['mean_accuracy'].std():.4f}" + ) print("\nPer-country statistics:") country_stats = ( - df.groupby("country")["consensus_score"].agg(["count", "mean", "std"]).round(4) + df.groupby("country")[["consensus_score", "mean_accuracy"]].agg(["count", "mean", "std"]).round(4) ) print(country_stats) diff --git a/ftw_tools/inference/inference.py b/ftw_tools/inference/inference.py index a0866ff..653c122 100644 --- a/ftw_tools/inference/inference.py +++ b/ftw_tools/inference/inference.py @@ -127,8 +127,9 @@ def run( compute_consensus: bool = False, preprocess_fn: Callable = default_preprocess, ): - if save_scores and compute_consensus: - raise ValueError("save_scores and compute_consensus are mutually exclusive.") + assert not (save_scores and compute_consensus), ( + "save_scores and compute_consensus are mutually exclusive." + ) device, transform, input_shape, patch_size, stride, padding = setup_inference( input, out, gpu, patch_size, padding, overwrite, mps_mode diff --git a/ftw_tools/inference/models.py b/ftw_tools/inference/models.py index ba4e13a..ce9dc67 100644 --- a/ftw_tools/inference/models.py +++ b/ftw_tools/inference/models.py @@ -125,6 +125,9 @@ def __init__( max_detections: int = 100, iou_threshold: float = 0.3, conf_threshold: float = 0.05, + percentile_low: float | None = 0.01, + percentile_high: float | None = 0.99, + norm_constant: float | None = None, device: str = "cuda" if torch.cuda.is_available() else "cpu", ) -> None: """Initialize the DelineateAnything model. @@ -136,9 +139,26 @@ def __init__( max_detections: Maximum number of detections per image. iou_threshold: Intersection over Union threshold for filtering predictions. conf_threshold: Confidence threshold for filtering predictions. + percentile_low: Lower percentile for per-channel normalization (default: 0.01). + Mutually exclusive with norm_constant. + percentile_high: Upper percentile for per-channel normalization (default: 0.99). + Mutually exclusive with norm_constant. + norm_constant: If provided, divide by this value and clip to [0, 1] instead of + using percentile normalization. Mutually exclusive with percentile_low/percentile_high. device: Device to run the model on, either "cuda" or "cpu". + + Raises: + ValueError: If both norm_constant and percentile options are specified. """ super().__init__() + + # Validate mutually exclusive normalization options + if norm_constant is not None and (percentile_low is not None or percentile_high is not None): + raise ValueError( + "norm_constant is mutually exclusive with percentile_low/percentile_high. " + "Set percentile_low=None and percentile_high=None when using norm_constant." + ) + self.patch_size = ( (patch_size, patch_size) if isinstance(patch_size, int) else patch_size ) @@ -149,6 +169,9 @@ def __init__( self.max_detections = max_detections self.iou_threshold = iou_threshold self.conf_threshold = conf_threshold + self.percentile_low = percentile_low + self.percentile_high = percentile_high + self.norm_constant = norm_constant self.device = device from ultralytics import YOLO @@ -156,14 +179,51 @@ def __init__( self.model.eval() self.model.fuse() self.transforms = nn.Sequential( - T.Lambda(lambda x: x.unsqueeze(dim=0) if x.ndim == 3 else x), - T.Lambda(lambda x: x[:, :3, ...]), - T.Lambda(lambda x: x / 3000.0), - T.Lambda(lambda x: x.clip(0.0, 1.0)), T.Resize(self.image_size, interpolation=T.InterpolationMode.BILINEAR), T.ConvertImageDtype(torch.float32), ).to(device) + def _normalize(self, x: torch.Tensor) -> torch.Tensor: + """Apply normalization based on configured method. + + Uses either constant normalization (if norm_constant is set) or + per-channel percentile normalization (if percentile_low/high are set). + + Args: + x: Input tensor of shape (B, C, H, W). + + Returns: + Normalized tensor clipped to [0, 1]. + """ + if self.norm_constant is not None: + # Simple constant normalization: divide and clip + return (x / self.norm_constant).clip(0.0, 1.0) + + # Per-channel percentile normalization + # At this point percentile_low and percentile_high are guaranteed to be non-None + assert self.percentile_low is not None and self.percentile_high is not None + + # x shape: (B, C, H, W) + B, C, H, W = x.shape + # Flatten spatial dimensions for percentile computation + x_flat = x.view(B, C, -1) # (B, C, H*W) + + # Compute percentiles per channel per batch + p_low = torch.quantile(x_flat.float(), self.percentile_low, dim=2, keepdim=True) # (B, C, 1) + p_high = torch.quantile(x_flat.float(), self.percentile_high, dim=2, keepdim=True) # (B, C, 1) + + # Reshape for broadcasting + p_low = p_low.view(B, C, 1, 1) + p_high = p_high.view(B, C, 1, 1) + + # Avoid division by zero + denom = p_high - p_low + denom = torch.where(denom == 0, torch.ones_like(denom), denom) + + # Normalize and clip + x_norm = (x - p_low) / denom + return x_norm.clip(0.0, 1.0) + @staticmethod def polygonize( result: "ultralytics.engine.results.Results", @@ -211,12 +271,24 @@ def __call__( """Forward pass through the model. Args: - image: The input image tensor, expected to be in the format (B, C, H, W). + image: The input image tensor, expected to be in the format (B, C, H, W) or (C, H, W). Returns: A list of results containing the model predictions. """ - image = self.transforms(image.to(self.device)) + # Add batch dimension if needed + if image.ndim == 3: + image = image.unsqueeze(0) + + # Select first 3 channels (RGB) + image = image[:, :3, ...] + + # Apply normalization (constant or percentile-based) + image = self._normalize(image.to(self.device)) + + # Apply remaining transforms (resize, dtype conversion) + image = self.transforms(image) + results = self.model.predict( image, conf=self.conf_threshold, diff --git a/ftw_tools/training/eval.py b/ftw_tools/training/eval.py index 877f077..3263d5e 100644 --- a/ftw_tools/training/eval.py +++ b/ftw_tools/training/eval.py @@ -1,7 +1,7 @@ import os import time from contextlib import contextmanager -from typing import Sequence +from typing import Literal, Sequence import kornia.augmentation as K import numpy as np @@ -18,6 +18,7 @@ from ftw_tools.training.datasets import FTW from ftw_tools.training.metrics import get_object_level_metrics from ftw_tools.training.trainers import CustomSemanticSegmentationTask +from ftw_tools.inference.models import DelineateAnything def expand_countries(countries: Sequence[str]) -> list[str]: @@ -534,3 +535,247 @@ def test( f.write( f"{model_path},{country_str},{pixel_level_iou},{pixel_level_precision},{pixel_level_recall},{object_precision},{object_recall},{object_f1}\n" ) + + +def test_delineate_anything( + dir: str, + gpu: int, + countries: Sequence[str], + iou_threshold: float, + out: str | None = None, + model_variant: Literal["DelineateAnything-S", "DelineateAnything"] = "DelineateAnything-S", + percentile_low: float | None = 0.02, + percentile_high: float | None = 0.98, + norm_constant: float | None = None, + resize_factor: int = 2, + patch_size: int = 256, + max_detections: int = 100, + conf_threshold: float = 0.05, + model_iou_threshold: float = 0.3, + num_workers: int = 4, + batch_size: int = 16, + use_val_set: bool = False, +): + """Test DelineateAnything model on FTW dataset. + + Args: + dir: Root directory of FTW dataset. + gpu: GPU device index (-1 for CPU). + countries: List of countries to test on. + iou_threshold: IoU threshold for object-level metrics. + out: Output CSV file path. + model_variant: "DelineateAnything-S" or "DelineateAnything". + percentile_low: Lower percentile for normalization. Mutually exclusive with norm_constant. + percentile_high: Upper percentile for normalization. Mutually exclusive with norm_constant. + norm_constant: If provided, divide by this value and clip to [0, 1] instead of + using percentile normalization. Mutually exclusive with percentile_low/percentile_high. + resize_factor: Factor to resize input images. + patch_size: Input patch size. + max_detections: Maximum detections per image. + conf_threshold: Confidence threshold for detections. + model_iou_threshold: IoU threshold for NMS in model. + num_workers: Number of dataloader workers. + batch_size: Batch size for inference. + use_val_set: If True, use validation set instead of test set. + + Raises: + ValueError: If both norm_constant and percentile options are specified. + """ + from scipy.ndimage import binary_erosion + + # Validate mutually exclusive normalization options + if norm_constant is not None and (percentile_low is not None or percentile_high is not None): + raise ValueError( + "norm_constant is mutually exclusive with percentile_low/percentile_high. " + "Set percentile_low=None and percentile_high=None when using norm_constant." + ) + + target_split = "val" if use_val_set else "test" + print(f"Running DelineateAnything test on the {target_split} set") + + if gpu is None: + gpu = -1 + + countries = expand_countries(countries) + + if torch.cuda.is_available() and gpu >= 0: + device = torch.device(f"cuda:{gpu}") + else: + device = torch.device("cpu") + + print(f"Loading DelineateAnything model ({model_variant})") + tic = time.time() + model = DelineateAnything( + model=model_variant, + patch_size=patch_size, + resize_factor=resize_factor, + max_detections=max_detections, + iou_threshold=model_iou_threshold, + conf_threshold=conf_threshold, + percentile_low=percentile_low, + percentile_high=percentile_high, + norm_constant=norm_constant, + device=str(device), + ) + print(f"Model loaded in {time.time() - tic:.2f}s") + + print("Creating dataloader") + tic = time.time() + ds = FTW( + root=dir, + countries=countries, + split=target_split, + load_boundaries=True, # 3-class masks for comparison + temporal_options="stacked", + ) + dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers) + print(f"Created dataloader with {len(ds)} samples in {time.time() - tic:.2f}s") + + # Metrics for 3-class evaluation (0=bg, 1=field, 2=boundary) + metrics = MetricCollection( + [ + JaccardIndex( + task="multiclass", average="none", num_classes=3, ignore_index=3 + ), + Precision( + task="multiclass", average="none", num_classes=3, ignore_index=3 + ), + Recall( + task="multiclass", average="none", num_classes=3, ignore_index=3 + ), + ] + ).to(device) + + all_tps = 0 + all_fps = 0 + all_fns = 0 + + print("Running inference") + for batch in tqdm(dl): + images = batch["image"] # (B, C, H, W) + masks = batch["mask"].to(device) # (B, H, W) + + # Run inference + with torch.inference_mode(): + results_list = model(images) + + # Convert instance segmentation results to 3-class masks + pred_masks = [] + for results in results_list: + h, w = patch_size, patch_size + instance_mask = np.zeros((h, w), dtype=np.int32) + + if results.masks is not None: + result_masks = results.masks.data.cpu().numpy() + from scipy.ndimage import zoom + + for idx, mask in enumerate(result_masks, start=1): + if mask.shape != (h, w): + scale_y = h / mask.shape[0] + scale_x = w / mask.shape[1] + mask = zoom(mask, (scale_y, scale_x), order=0) + instance_mask[mask > 0.5] = idx + + # Convert to 3-class: 0=bg, 1=interior, 2=boundary + boundary_mask = np.zeros((h, w), dtype=np.uint8) + for idx in range(1, instance_mask.max() + 1): + instance_binary = instance_mask == idx + if not np.any(instance_binary): + continue + interior = binary_erosion(instance_binary, iterations=1) + boundary = instance_binary & ~interior + boundary_mask[interior] = 1 + boundary_mask[boundary] = 2 + + pred_masks.append(boundary_mask) + + pred_masks = torch.from_numpy(np.stack(pred_masks)).to(device) + + # Update metrics + metrics.update(pred_masks, masks) + + # Object-level metrics + pred_masks_np = pred_masks.cpu().numpy().astype(np.uint8) + masks_np = masks.cpu().numpy().astype(np.uint8) + + for i in range(len(pred_masks_np)): + # Convert to binary for object metrics (field vs background) + pred_binary = (pred_masks_np[i] > 0).astype(np.uint8) + mask_binary = (masks_np[i] > 0).astype(np.uint8) + tps, fps, fns = get_object_level_metrics( + mask_binary, pred_binary, iou_threshold=iou_threshold + ) + all_tps += tps + all_fps += fps + all_fns += fns + + # Compute metrics + results = metrics.compute() + pixel_level_iou = results["MulticlassJaccardIndex"][1].item() + pixel_level_precision = results["MulticlassPrecision"][1].item() + pixel_level_recall = results["MulticlassRecall"][1].item() + + boundary_iou = results["MulticlassJaccardIndex"][2].item() + boundary_precision = results["MulticlassPrecision"][2].item() + boundary_recall = results["MulticlassRecall"][2].item() + + if all_tps + all_fps > 0: + object_precision = all_tps / (all_tps + all_fps) + else: + object_precision = float("nan") + + if all_tps + all_fns > 0: + object_recall = all_tps / (all_tps + all_fns) + else: + object_recall = float("nan") + + if object_precision + object_recall > 0 and not ( + np.isnan(object_precision) or np.isnan(object_recall) + ): + object_f1 = ( + 2 * object_precision * object_recall / (object_precision + object_recall) + ) + else: + object_f1 = float("nan") + + print(f"\n{'='*60}") + print(f"DelineateAnything ({model_variant}) Results") + if norm_constant is not None: + print(f"Constant normalization: {norm_constant}") + else: + print(f"Percentile normalization: [{percentile_low}, {percentile_high}]") + print(f"{'='*60}") + print(f"Field pixel IoU: {pixel_level_iou:.4f}") + print(f"Field pixel precision: {pixel_level_precision:.4f}") + print(f"Field pixel recall: {pixel_level_recall:.4f}") + print(f"Boundary pixel IoU: {boundary_iou:.4f}") + print(f"Boundary pixel precision: {boundary_precision:.4f}") + print(f"Boundary pixel recall: {boundary_recall:.4f}") + print(f"Object precision: {object_precision:.4f}") + print(f"Object recall: {object_recall:.4f}") + print(f"Object F1: {object_f1:.4f}") + + country_str = ";".join(countries) + if set(countries) == set(FULL_DATA_COUNTRIES): + country_str = "all" + + if out is not None: + if not os.path.exists(out): + with open(out, "w") as f: + f.write( + "model,countries,percentile_low,percentile_high,norm_constant,resize_factor,conf_threshold,model_iou_threshold," + "field_pixel_iou,field_pixel_precision,field_pixel_recall," + "boundary_pixel_iou,boundary_pixel_precision,boundary_pixel_recall," + "object_precision,object_recall,object_f1\n" + ) + with open(out, "a") as f: + norm_const_str = str(norm_constant) if norm_constant is not None else "" + percentile_low_str = str(percentile_low) if percentile_low is not None else "" + percentile_high_str = str(percentile_high) if percentile_high is not None else "" + f.write( + f"{model_variant},{country_str},{percentile_low_str},{percentile_high_str},{norm_const_str},{resize_factor},{conf_threshold},{model_iou_threshold}," + f"{pixel_level_iou},{pixel_level_precision},{pixel_level_recall}," + f"{boundary_iou},{boundary_precision},{boundary_recall}," + f"{object_precision},{object_recall},{object_f1}\n" + ) + print(f"\nResults saved to {out}") diff --git a/ftw_tools/training/trainers.py b/ftw_tools/training/trainers.py old mode 100644 new mode 100755 index 98c0bd6..90ab0d1 --- a/ftw_tools/training/trainers.py +++ b/ftw_tools/training/trainers.py @@ -1,5 +1,6 @@ """Trainer for semantic segmentation.""" +import logging import warnings from typing import Any, Optional, Union @@ -20,6 +21,9 @@ from torchgeo.trainers.base import BaseTask from torchmetrics import MetricCollection from torchmetrics.classification import ( + BinaryJaccardIndex, + BinaryPrecision, + BinaryRecall, MulticlassJaccardIndex, MulticlassPrecision, MulticlassRecall, @@ -40,6 +44,9 @@ from .metrics import get_object_level_metrics from .utils import batch_corner_consensus_from_model +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + class CustomSemanticSegmentationTask(BaseTask): """Semantic Segmentation. @@ -66,7 +73,8 @@ def __init__( freeze_backbone: bool = False, freeze_decoder: bool = False, edge_agreement_loss: bool = False, - model_kwargs: dict[Any, Any] = dict(), + pretrained_checkpoint: Optional[str] = None, + model_kwargs: Optional[dict[Any, Any]] = None, ) -> None: """Initialize a new SemanticSegmentationTask instance. @@ -101,6 +109,10 @@ class and used with 'ce' loss. the segmentation head. edge_agreement_loss: If True, ignore non-edge pixels by remapping them to the reserved "unknown" class index before loss computation. + pretrained_checkpoint: Path to a checkpoint file from which to load + encoder and decoder weights. This is used for transfer learning from + edge pre-training. If provided, weights are loaded after model + initialization, overriding ImageNet or random weights. model_kwargs: Additional keyword arguments to pass to the model Warns: @@ -325,7 +337,7 @@ def configure_models(self) -> None: in_channels: int = self.hparams["in_channels"] num_classes: int = self.hparams["num_classes"] num_filters: int = self.hparams["num_filters"] - model_kwargs: dict[Any, Any] = self.hparams["model_kwargs"] + model_kwargs: dict[Any, Any] = self.hparams["model_kwargs"] or {} patch_weights: bool = self.hparams["patch_weights"] if model == "unet": @@ -407,6 +419,12 @@ def configure_models(self) -> None: if in_channels < 5 and model in ["fcsiamdiff", "fcsiamconc", "fcsiamavg"]: raise ValueError("FCSiam models require more than one input image.") + # Load encoder and decoder weights from a pre-trained checkpoint (e.g., edge pre-training) + pretrained_checkpoint = self.hparams.get("pretrained_checkpoint") + if pretrained_checkpoint is not None: + logger.info("Loading from checkpoint: %s", pretrained_checkpoint) + self._load_pretrained_weights(pretrained_checkpoint) + # Freeze backbone if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]: for param in self.model.encoder.parameters(): @@ -420,6 +438,63 @@ def configure_models(self) -> None: if patch_weights: self.transfer_weights(self.model, backbone) + def _load_pretrained_weights(self, checkpoint_path: str) -> None: + """Load encoder and decoder weights from a checkpoint file. + + Args: + checkpoint_path: Path to the checkpoint file (.ckpt). + """ + ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + state_dict = ckpt.get("state_dict", ckpt) + + # Extract encoder weights (keys start with "model.encoder.") + encoder_state = {} + for key, value in state_dict.items(): + if key.startswith("model.encoder."): + new_key = key.replace("model.encoder.", "") + encoder_state[new_key] = value + + # Extract decoder weights (keys start with "model.decoder.") + decoder_state = {} + for key, value in state_dict.items(): + if key.startswith("model.decoder."): + new_key = key.replace("model.decoder.", "") + decoder_state[new_key] = value + + if not encoder_state and not decoder_state: + raise ValueError( + f"No encoder or decoder weights found in checkpoint {checkpoint_path}. " + "Expected keys starting with 'model.encoder.' or 'model.decoder.'" + ) + + # Load encoder weights + if encoder_state: + result = self.model.encoder.load_state_dict(encoder_state, strict=False) + if result is None: + logger.info("Loaded encoder weights from %s.", checkpoint_path) + else: + missing, unexpected = result + logger.info( + "Loaded encoder weights from %s. Missing: %d, Unexpected: %d", + checkpoint_path, + len(missing), + len(unexpected), + ) + + # Load decoder weights + if decoder_state: + result = self.model.decoder.load_state_dict(decoder_state, strict=False) + if result is None: + logger.info("Loaded decoder weights from %s.", checkpoint_path) + else: + missing, unexpected = result + logger.info( + "Loaded decoder weights from %s. Missing: %d, Unexpected: %d", + checkpoint_path, + len(missing), + len(unexpected), + ) + def _log_per_class(self, metrics_dict, split: str): # metrics_dict like {"precision": tensor(C,), "recall": tensor(C,), "iou": tensor(C,)} for name, values in metrics_dict.items(): @@ -725,3 +800,262 @@ def transfer_weights(self, model, backbone): model.load_state_dict(model_dict) else: print("Due to mismatch in the Tensor size, unable to patch weights.") + + +class EdgePretrainingTask(BaseTask): + """Pre-training task for edge prediction. + + This task trains a segmentation model to predict binary edge masks from + satellite imagery. The pre-trained encoder can then be used as initialization + for the main segmentation task. + """ + + def __init__( + self, + model: str = "unet", + backbone: str = "efficientnet-b3", + weights: Optional[Union[WeightsEnum, str, bool]] = True, + in_channels: int = 8, + lr: float = 1e-3, + patience: int = 100, + model_kwargs: Optional[dict[Any, Any]] = None, + ) -> None: + """Initialize EdgePretrainingTask. + + Args: + model: Name of the segmentation model (e.g., "unet"). + backbone: Name of the encoder backbone (e.g., "efficientnet-b3"). + weights: Initial encoder weights. True for ImageNet, False/None for random. + in_channels: Number of input channels. + lr: Learning rate. + patience: Patience for cosine annealing scheduler. + model_kwargs: Additional keyword arguments for the model. + """ + self.weights = weights + super().__init__() + + def configure_losses(self) -> None: + """Initialize the loss criterion (BCE + Dice for edge prediction).""" + self.bce_loss = nn.BCEWithLogitsLoss() + self.dice_loss = smp.losses.DiceLoss(mode="binary", from_logits=True) + + def configure_metrics(self) -> None: + """Initialize the performance metrics for binary edge prediction.""" + base_metrics = { + "precision": BinaryPrecision(), + "recall": BinaryRecall(), + "iou": BinaryJaccardIndex(), + } + self.train_metrics = MetricCollection(base_metrics, prefix="train/") + self.val_metrics = self.train_metrics.clone(prefix="val/") + self.test_metrics = self.train_metrics.clone(prefix="test/") + + def configure_models(self) -> None: + """Initialize the segmentation model with 1 output class for binary edges.""" + model: str = self.hparams["model"] + backbone: str = self.hparams["backbone"] + weights = self.weights + in_channels: int = self.hparams["in_channels"] + model_kwargs: dict[Any, Any] = self.hparams["model_kwargs"] or {} + + if model == "unet": + self.model = smp.Unet( + encoder_name=backbone, + encoder_weights="imagenet" if weights is True else None, + in_channels=in_channels, + classes=1, + **model_kwargs, + ) + elif model == "deeplabv3+": + self.model = smp.DeepLabV3Plus( + encoder_name=backbone, + encoder_weights="imagenet" if weights is True else None, + in_channels=in_channels, + classes=1, + **model_kwargs, + ) + else: + raise ValueError( + f"Model '{model}' not supported for edge pretraining. Use 'unet' or 'deeplabv3+'." + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.model(x) + + def _log_edge_visualization( + self, x: Tensor, edge: Tensor, y_hat: Tensor, batch_idx: int + ) -> None: + """Log visualization of edge predictions to TensorBoard. + + Args: + x: Input image tensor (B, C, H, W). + edge: Ground truth edge mask (B, H, W). + y_hat: Raw model output logits (B, H, W). + batch_idx: Batch index for labeling. + """ + # Take first sample from batch + img = x[0].cpu().numpy() + gt_edge = edge[0].cpu().numpy() + pred_prob = torch.sigmoid(y_hat[0]).cpu().numpy() + pred_binary = (pred_prob > 0.5).astype(np.float32) + + # Check if we have stacked temporal windows (8 channels = 2x4 bands) + num_channels = img.shape[0] + has_two_windows = num_channels >= 8 + + if has_two_windows: + # Two temporal windows stacked + img1 = img[:3].transpose(1, 2, 0) + img2 = img[4:7].transpose(1, 2, 0) + num_panels = 5 + else: + # Single window + img1 = img[:3].transpose(1, 2, 0) + img2 = None + num_panels = 4 + + fig, axes = plt.subplots(1, num_panels, figsize=(num_panels * 4, 4)) + + # Use np.clip like FTW.plot() for consistency + axes[0].imshow(np.clip(img1, 0, 1)) + axes[0].set_title("Window B") + axes[0].axis("off") + + panel_idx = 1 + if has_two_windows: + axes[panel_idx].imshow(np.clip(img2, 0, 1)) + axes[panel_idx].set_title("Window A") + axes[panel_idx].axis("off") + panel_idx += 1 + + axes[panel_idx].imshow(gt_edge, cmap="gray", vmin=0, vmax=1) + axes[panel_idx].set_title("Ground Truth Edge") + axes[panel_idx].axis("off") + + axes[panel_idx + 1].imshow(pred_prob, cmap="gray", vmin=0, vmax=1) + axes[panel_idx + 1].set_title("Predicted Probability") + axes[panel_idx + 1].axis("off") + + axes[panel_idx + 2].imshow(pred_binary, cmap="gray", vmin=0, vmax=1) + axes[panel_idx + 2].set_title("Predicted Edge (>0.5)") + axes[panel_idx + 2].axis("off") + + plt.tight_layout() + + # Log to all available loggers + for log in self.loggers: + if hasattr(log, "experiment") and hasattr(log.experiment, "add_figure"): + log.experiment.add_figure( + f"edge_val/{batch_idx}", fig, global_step=self.global_step + ) + + plt.close(fig) + + def training_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute training loss for edge prediction. + + Args: + batch: Dictionary with "image" and "edge" tensors. + batch_idx: Index of this batch. + dataloader_idx: Index of the current dataloader. + + Returns: + The loss tensor. + """ + x = batch["image"] + # Binarize edges: edge > 0 means edge pixel + edge = (batch["edge"].squeeze(1) > 0).float() + + y_hat = self(x).squeeze(1) # Shape: (B, H, W) + + bce = self.bce_loss(y_hat, edge) + dice = self.dice_loss(y_hat.unsqueeze(1), edge.unsqueeze(1)) + loss = bce + dice + + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train/bce", bce, on_step=False, on_epoch=True) + self.log("train/dice", dice, on_step=False, on_epoch=True) + + preds = (torch.sigmoid(y_hat) > 0.5).long() + self.train_metrics.update(preds, edge.long()) + + return loss + + def validation_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Compute validation loss for edge prediction.""" + x = batch["image"] + edge = (batch["edge"].squeeze(1) > 0).float() + + y_hat = self(x).squeeze(1) + + bce = self.bce_loss(y_hat, edge) + dice = self.dice_loss(y_hat.unsqueeze(1), edge.unsqueeze(1)) + loss = bce + dice + + self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("val/bce", bce, on_step=False, on_epoch=True) + self.log("val/dice", dice, on_step=False, on_epoch=True) + + preds = (torch.sigmoid(y_hat) > 0.5).long() + self.val_metrics.update(preds, edge.long()) + + # Validation visualization for first few batches + if ( + batch_idx < 10 + and self.logger + and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") + ): + self._log_edge_visualization(x, edge, y_hat, batch_idx) + + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + """Compute test loss for edge prediction.""" + x = batch["image"] + edge = (batch["edge"].squeeze(1) > 0).float() + + y_hat = self(x).squeeze(1) + + bce = self.bce_loss(y_hat, edge) + dice = self.dice_loss(y_hat.unsqueeze(1), edge.unsqueeze(1)) + loss = bce + dice + + self.log("test/loss", loss) + + preds = (torch.sigmoid(y_hat) > 0.5).long() + self.test_metrics.update(preds, edge.long()) + + def configure_optimizers( + self, + ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig": + """Initialize optimizer and learning rate scheduler.""" + optimizer = AdamW(self.parameters(), lr=self.hparams["lr"], amsgrad=True) + scheduler = CosineAnnealingLR( + optimizer, T_max=self.hparams["patience"], eta_min=1e-6 + ) + return { + "optimizer": optimizer, + "lr_scheduler": {"scheduler": scheduler, "monitor": "val/loss"}, + } + + def on_train_epoch_end(self) -> None: + """Log per-epoch training metrics.""" + computed = self.train_metrics.compute() + self.log_dict(computed, on_step=False, on_epoch=True) + self.train_metrics.reset() + + def on_validation_epoch_end(self) -> None: + """Log per-epoch validation metrics.""" + computed = self.val_metrics.compute() + self.log_dict(computed, on_step=False, on_epoch=True) + self.val_metrics.reset() + + def on_test_epoch_end(self) -> None: + """Log per-epoch test metrics.""" + computed = self.test_metrics.compute() + self.log_dict(computed, on_step=False, on_epoch=True) + self.test_metrics.reset()