Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 40 additions & 42 deletions ftw_tools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,43 +29,37 @@
# All parameters are meant to use underscores as separator for words.


class ModelOrCheckpointParamType(click.ParamType):
"""
A custom Click parameter type that accepts either:
1. A model name from MODEL_REGISTRY, or
2. A path to a .ckpt checkpoint file
"""

class ModelOrCheckpointParamType(click.Choice):
name = "model_or_checkpoint"

def __init__(self):
super().__init__(MODEL_REGISTRY.keys())
self._path = click.Path(exists=True, dir_okay=False)

def get_metavar(self, param, ctx):
return f"{super().get_metavar(param, ctx)} or PATH"

def get_missing_message(self, param):
return (
super().get_missing_message(param) + " or provide a PATH to a .ckpt file."
)

def convert(self, value, param, ctx):
"""
Convert and validate the parameter value.

Args:
value: The raw parameter value from CLI
param: The parameter object
ctx: The click context

Returns:
The validated value (unchanged)

Raises:
click.BadParameter if validation fails
"""
# Check if it's a model name from the registry
if value in MODEL_REGISTRY.keys():
return value

# Check if it's a checkpoint path
if value.endswith(".ckpt"):
if os.path.exists(value):
return value
else:
self.fail(f"Checkpoint file '{value}' does not exist.", param, ctx)
try:
return super().convert(value, param, ctx)
except click.BadParameter:
pass

try:
file = self._path.convert(value, param, ctx)
# Check if it's a checkpoint path
if file.endswith(".ckpt"):
return file
except click.BadParameter:
pass

# If neither, provide helpful error message
available_models = ", ".join(MODEL_REGISTRY.keys())
available_models = ", ".join(self.choices)
self.fail(
f"'{value}' is not a valid model name or checkpoint path. "
f"Valid model names are: {available_models}. "
Expand All @@ -74,6 +68,11 @@ def convert(self, value, param, ctx):
ctx,
)

def shell_complete(self, ctx, param, incomplete):
items = super().shell_complete(ctx, param, incomplete)
items.extend(self._path.shell_complete(ctx, param, incomplete))
return items


# Common parameter definitions for shared CLI options
def common_bbox_option():
Expand Down Expand Up @@ -224,6 +223,7 @@ def data_download(out, clean_download, countries, no_unpack):
type=click.Path(exists=True, dir_okay=True, file_okay=False),
default="./data",
required=False,
nargs=1,
)
def data_unpack(input):
from ftw_tools.download.unpack import unpack
Expand Down Expand Up @@ -656,17 +656,15 @@ def inference_download(

@inference.command(
"run",
help="Run inference on the stacked Sentinel-2 L2A satellite images specified via INPUT.",
help="Run inference on the stacked Sentinel-2 L2A satellite image specified via INPUT.",
)
@click.argument("input", type=click.Path(exists=True), required=True)
@click.argument("input", type=click.Path(exists=True), required=True, nargs=1)
@click.option(
"--model",
"-m",
type=ModelOrCheckpointParamType(),
required=True,
help="Short model name from the registry (one of: "
+ ", ".join(MODEL_REGISTRY.keys())
+ ") OR path to a checkpoint file (.ckpt).",
help="Short model name from the registry OR path to a checkpoint file (.ckpt).",
)
@click.option(
"--out",
Expand Down Expand Up @@ -776,9 +774,9 @@ def inference_run(

@inference.command(
"run-instance-segmentation",
help="Run an instance segmentation model inference on a single Sentinel-2 L2A satellite images specified via INPUT.",
help="Run an instance segmentation model inference on a single Sentinel-2 L2A satellite image specified via INPUT.",
)
@click.argument("input", type=click.Path(exists=True), required=True)
@click.argument("input", type=click.Path(exists=True), required=True, nargs=1)
@click.option(
"--model",
"-m",
Expand Down Expand Up @@ -976,7 +974,7 @@ def inference_run_instance_segmentation(
"instance-segmentation-all",
help="Run all inference instance segmentation commands from download and inference.",
)
@click.argument("input", type=str, required=True)
@click.argument("input", type=str, required=True, nargs=1)
@click.option(
"--bbox",
type=str,
Expand Down Expand Up @@ -1211,7 +1209,7 @@ def inference_run_instance_segmentation_all(
"polygonize",
help="Polygonize the output from inference for the raster image given via INPUT. Results are in the CRS of the given raster image.",
)
@click.argument("input", type=click.Path(exists=True), required=True)
@click.argument("input", type=click.Path(exists=True), required=True, nargs=1)
@click.option(
"--out",
"-o",
Expand Down Expand Up @@ -1363,7 +1361,7 @@ def inference_polygonize(
@inference.command(
"filter-by-lulc", help="Filter the output raster in GeoTIFF format by LULC mask."
)
@click.argument("input", type=click.Path(exists=True), required=True)
@click.argument("input", type=click.Path(exists=True), required=True, nargs=1)
@click.option(
"--out",
"-o",
Expand Down