Skip to content
Open

Dev #248

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion configs/release/check_middle_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def main(args: argparse.Namespace):

for batch in tqdm(dl, desc=f"Processing {country} batches", leave=False):
images = batch["image"]
masks = batch["mask"]
if model_type in ["fcsiamdiff", "fcsiamconc", "fcsiamavg"]:
images = rearrange(images, "b (t c) h w -> b t c h w", t=2)

Expand Down Expand Up @@ -141,13 +142,28 @@ def main(args: argparse.Namespace):
)
consensus_score = consensus.mean()

# Extract central region from ground truth mask
gt_central = masks[i, padding : padding + overlap_size, padding : padding + overlap_size].numpy()

# Compute accuracy for each corner's prediction
acc1 = (hard_output[0] == gt_central).mean()
acc2 = (hard_output[1] == gt_central).mean()
acc3 = (hard_output[2] == gt_central).mean()
acc4 = (hard_output[3] == gt_central).mean()
mean_accuracy = (acc1 + acc2 + acc3 + acc4) / 4

# Store result for this patch
all_results.append(
{
"model_checkpoint": args.model,
"country": country,
"patch_idx": patch_idx,
"consensus_score": consensus_score,
"accuracy_corner1": acc1,
"accuracy_corner2": acc2,
"accuracy_corner3": acc3,
"accuracy_corner4": acc4,
"mean_accuracy": mean_accuracy,
"split": args.split,
}
)
Expand All @@ -167,10 +183,13 @@ def main(args: argparse.Namespace):
print(f"Min consensus: {df['consensus_score'].min():.4f}")
print(f"Max consensus: {df['consensus_score'].max():.4f}")
print(f"Median consensus: {df['consensus_score'].median():.4f}")
print(
f"Mean accuracy: {df['mean_accuracy'].mean():.4f} +/- {df['mean_accuracy'].std():.4f}"
)

print("\nPer-country statistics:")
country_stats = (
df.groupby("country")["consensus_score"].agg(["count", "mean", "std"]).round(4)
df.groupby("country")[["consensus_score", "mean_accuracy"]].agg(["count", "mean", "std"]).round(4)
)
print(country_stats)

Expand Down
5 changes: 3 additions & 2 deletions ftw_tools/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ def run(
compute_consensus: bool = False,
preprocess_fn: Callable = default_preprocess,
):
if save_scores and compute_consensus:
raise ValueError("save_scores and compute_consensus are mutually exclusive.")
assert not (save_scores and compute_consensus), (
"save_scores and compute_consensus are mutually exclusive."
)

device, transform, input_shape, patch_size, stride, padding = setup_inference(
input, out, gpu, patch_size, padding, overwrite, mps_mode
Expand Down
84 changes: 78 additions & 6 deletions ftw_tools/inference/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def __init__(
max_detections: int = 100,
iou_threshold: float = 0.3,
conf_threshold: float = 0.05,
percentile_low: float | None = 0.01,
percentile_high: float | None = 0.99,
norm_constant: float | None = None,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> None:
"""Initialize the DelineateAnything model.
Expand All @@ -136,9 +139,26 @@ def __init__(
max_detections: Maximum number of detections per image.
iou_threshold: Intersection over Union threshold for filtering predictions.
conf_threshold: Confidence threshold for filtering predictions.
percentile_low: Lower percentile for per-channel normalization (default: 0.01).
Mutually exclusive with norm_constant.
percentile_high: Upper percentile for per-channel normalization (default: 0.99).
Mutually exclusive with norm_constant.
norm_constant: If provided, divide by this value and clip to [0, 1] instead of
using percentile normalization. Mutually exclusive with percentile_low/percentile_high.
device: Device to run the model on, either "cuda" or "cpu".

Raises:
ValueError: If both norm_constant and percentile options are specified.
"""
super().__init__()

# Validate mutually exclusive normalization options
if norm_constant is not None and (percentile_low is not None or percentile_high is not None):
raise ValueError(
"norm_constant is mutually exclusive with percentile_low/percentile_high. "
"Set percentile_low=None and percentile_high=None when using norm_constant."
)

self.patch_size = (
(patch_size, patch_size) if isinstance(patch_size, int) else patch_size
)
Expand All @@ -149,21 +169,61 @@ def __init__(
self.max_detections = max_detections
self.iou_threshold = iou_threshold
self.conf_threshold = conf_threshold
self.percentile_low = percentile_low
self.percentile_high = percentile_high
self.norm_constant = norm_constant
self.device = device
from ultralytics import YOLO

self.model = YOLO(self.checkpoints[model]).to(device)
self.model.eval()
self.model.fuse()
self.transforms = nn.Sequential(
T.Lambda(lambda x: x.unsqueeze(dim=0) if x.ndim == 3 else x),
T.Lambda(lambda x: x[:, :3, ...]),
T.Lambda(lambda x: x / 3000.0),
T.Lambda(lambda x: x.clip(0.0, 1.0)),
T.Resize(self.image_size, interpolation=T.InterpolationMode.BILINEAR),
T.ConvertImageDtype(torch.float32),
).to(device)

def _normalize(self, x: torch.Tensor) -> torch.Tensor:
    """Scale an image batch into [0, 1] using the configured normalization.

    Two modes are supported, selected by the instance configuration:

    * Constant: when ``self.norm_constant`` is set, the input is divided by
      that value and clipped.
    * Percentile: otherwise each (sample, channel) plane is rescaled so that
      its ``self.percentile_low`` quantile maps to 0 and its
      ``self.percentile_high`` quantile maps to 1, then clipped.

    Args:
        x: Input tensor of shape (B, C, H, W). It does not need to be
            contiguous.

    Returns:
        Normalized tensor with the same shape as ``x``, clipped to [0, 1].

    Raises:
        ValueError: If ``norm_constant`` is None and either percentile is
            also None, or if ``x`` is not 4-dimensional.
    """
    if self.norm_constant is not None:
        # Simple constant normalization: divide and clip.
        return (x / self.norm_constant).clip(0.0, 1.0)

    # Explicit check instead of ``assert``: asserts are removed under ``-O``
    # and a half-configured model would otherwise fail deep inside torch.
    if self.percentile_low is None or self.percentile_high is None:
        raise ValueError(
            "percentile_low and percentile_high must both be set when "
            "norm_constant is None."
        )

    # Unpacking four names keeps the original requirement of a 4-D input.
    batch, channels, _, _ = x.shape

    # ``reshape`` (unlike ``view``) also works for non-contiguous inputs such
    # as transposed/permuted tensors. ``torch.quantile`` needs a floating
    # dtype, so convert once.
    flat = x.reshape(batch, channels, -1).float()  # (B, C, H*W)

    # Compute both percentiles in a single call so the data is sorted once.
    q = torch.tensor(
        [self.percentile_low, self.percentile_high],
        dtype=flat.dtype,
        device=flat.device,
    )
    bounds = torch.quantile(flat, q, dim=2, keepdim=True)  # (2, B, C, 1)
    p_low = bounds[0].unsqueeze(-1)  # (B, C, 1, 1) for broadcasting
    p_high = bounds[1].unsqueeze(-1)  # (B, C, 1, 1)

    # A constant plane has p_high == p_low; divide by 1 there so the result
    # is 0 rather than NaN/inf.
    denom = p_high - p_low
    denom = torch.where(denom == 0, torch.ones_like(denom), denom)

    return ((x - p_low) / denom).clip(0.0, 1.0)

@staticmethod
def polygonize(
result: "ultralytics.engine.results.Results",
Expand Down Expand Up @@ -211,12 +271,24 @@ def __call__(
"""Forward pass through the model.

Args:
image: The input image tensor, expected to be in the format (B, C, H, W).
image: The input image tensor, expected to be in the format (B, C, H, W) or (C, H, W).

Returns:
A list of results containing the model predictions.
"""
image = self.transforms(image.to(self.device))
# Add batch dimension if needed
if image.ndim == 3:
image = image.unsqueeze(0)

# Select first 3 channels (RGB)
image = image[:, :3, ...]

# Apply normalization (constant or percentile-based)
image = self._normalize(image.to(self.device))

# Apply remaining transforms (resize, dtype conversion)
image = self.transforms(image)

results = self.model.predict(
image,
conf=self.conf_threshold,
Expand Down
Loading
Loading