diff --git a/end_to_end/imagenet/README.md b/end_to_end/imagenet/README.md
index cae16736..2c6a3fd9 100644
--- a/end_to_end/imagenet/README.md
+++ b/end_to_end/imagenet/README.md
@@ -40,7 +40,11 @@ This will generate mobilenet_v2.tflite in the current directory.
 Supported architectures include: `mobilenet_v2`, `resnet18`, `resnet34`,
 `resnet50`, `resnet101`, `resnet152`, `efficientnet_b0` through
 `efficientnet_b7`, `efficientnet_v2_s`, `efficientnet_v2_m`,
-`efficientnet_v2_l`.
+`alexnet`, `convnext_tiny`, `convnext_small`, `convnext_base`, `convnext_large`,
+`vgg11`, `vgg11_bn`, `vgg13`, `vgg13_bn`, `vgg16`, `vgg16_bn`, `vgg19`, `vgg19_bn`,
+`efficientnet_v2_l`, `shufflenet_v2_x0_5`, `shufflenet_v2_x1_0`,
+`shufflenet_v2_x1_5`, `shufflenet_v2_x2_0`, `squeezenet1_0`, `squeezenet1_1`,
+`inception_v3`, `mobilenet_v3_large`, `mobilenet_v3_small`.
 
 If you choose a different architecture, the default output name matches it
 (for example, `resnet18.tflite`).
diff --git a/end_to_end/imagenet/eval.py b/end_to_end/imagenet/eval.py
new file mode 100644
index 00000000..e15407a9
--- /dev/null
+++ b/end_to_end/imagenet/eval.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python3
+# Copyright 2026 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
+ +"""Evaluate a LiteRT model on ImageNet-1K validation split.""" + +import argparse +import os +import sys + +import numpy as np +from ai_edge_litert.compiled_model import CompiledModel +from PIL import Image + +from datasets import load_dataset +import evaluate +from tqdm import tqdm + +import main as imagenet_main + + +def _default_model_path() -> str: + return os.path.join(os.getcwd(), "mobilenet_v2.tflite") + + +def _preprocess_image( + image: Image.Image, + channels: int, + resize_size: int, + crop_height: int, + crop_width: int, + mean: np.ndarray, + std: np.ndarray, + resample: int, + channels_first: bool, +) -> np.ndarray: + if channels != 3: + raise ValueError(f"Expected 3 channels, got {channels}") + if resize_size <= 0 or crop_height <= 0 or crop_width <= 0: + raise ValueError( + f"Invalid resize/crop size: resize={resize_size}, crop={crop_height}x{crop_width}" + ) + image = image.convert("RGB") + width, height = image.size + if width < height: + new_width = resize_size + new_height = int(round(height * resize_size / width)) + else: + new_height = resize_size + new_width = int(round(width * resize_size / height)) + image = image.resize((new_width, new_height), resample) + left = int(round((new_width - crop_width) / 2.0)) + top = int(round((new_height - crop_height) / 2.0)) + image = image.crop((left, top, left + crop_width, top + crop_height)) + array = np.asarray(image, dtype=np.int32) + array = array.astype(np.float32) / 255.0 + array = (array - mean) / std + if channels_first: + return np.transpose(array, (2, 0, 1)) + return array + + +def _topk_indices(scores: np.ndarray, k: int) -> np.ndarray: + flat = scores.reshape(-1) + if k <= 1: + return np.array([int(np.argmax(flat))], dtype=np.int64) + if flat.size <= k: + return np.argsort(flat)[::-1] + idx = np.argpartition(flat, -k)[-k:] + return idx[np.argsort(flat[idx])[::-1]] + + +def _parse_args(argv: list[str]): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", 
default=_default_model_path())
+    parser.add_argument(
+        "--arch",
+        default=None,
+        help=(
+            "Optional model architecture name to select preprocessing. "
+            "If set, overrides inference based on --model filename."
+        ),
+    )
+    parser.add_argument(
+        "--max_samples",
+        type=int,
+        default=0,
+        help="If set, stop after this many samples.",
+    )
+    return parser.parse_args(argv)
+
+
+def _iter_dataset(dataset, max_samples: int):
+    count = 0
+    for item in dataset:
+        yield item
+        count += 1
+        if max_samples and count >= max_samples:
+            break
+
+
+def main(argv: list[str]) -> int:
+    args = _parse_args(argv)
+
+    if not os.path.exists(args.model):
+        raise FileNotFoundError(f"Model not found: {args.model}")
+
+    model = CompiledModel.from_file(args.model)
+    signature_index = 0
+    channels = 3
+
+    input_height, input_width, channels_first = imagenet_main.infer_input_size(  # re-exported from imagenet_preprocessing
+        model, signature_index
+    )
+    preprocess_model_key = args.arch if args.arch else args.model
+    preprocess = imagenet_main.pick_preprocess_config(  # re-exported from imagenet_preprocessing
+        preprocess_model_key, input_height, input_width
+    )
+    layout = "NCHW" if channels_first else "NHWC"
+    print(
+        "Model input:",
+        f"{input_height}x{input_width}",
+        layout,
+        f"resize={preprocess['resize_size']}",
+        f"crop={preprocess['crop_height']}x{preprocess['crop_width']}",
+    )
+
+    input_buffers = model.create_input_buffers(signature_index)
+    output_buffers = model.create_output_buffers(signature_index)
+    output_requirements = model.get_output_buffer_requirements(0, signature_index)
+
+    output_dtype = imagenet_main._pick_output_dtype(output_requirements)
+    buffer_size = output_requirements.get("buffer_size", 0)
+    itemsize = np.dtype(output_dtype).itemsize
+    output_size = buffer_size // itemsize if itemsize else buffer_size
+    if output_size == 0:
+        raise ValueError("Output buffer size is zero")
+    output_offset = 1 if output_size == 1001 else 0
+
+    dataset = load_dataset(
+        "imagenet-1k",
+        split="validation",
+        token=True,
+    )
+
+    accuracy_metric = 
evaluate.load("accuracy") + correct_top5 = 0 + total = 0 + + total_hint = len(dataset) + iterator = _iter_dataset(dataset, args.max_samples) + for example in tqdm(iterator, total=total_hint, unit="img"): + image = example["image"] + label = int(example["label"]) + output_offset + + input_array = _preprocess_image( + image, + channels, + preprocess["resize_size"], + preprocess["crop_height"], + preprocess["crop_width"], + preprocess["mean"], + preprocess["std"], + preprocess["resample"], + channels_first, + ) + + input_buffers[0].write(input_array) + model.run_by_index(signature_index, input_buffers, output_buffers) + + output_array = imagenet_main._read_output(output_buffers[0], output_requirements) + scores = output_array.reshape(-1) + pred = int(np.argmax(scores)) + top5 = _topk_indices(scores, 5) + + accuracy_metric.add_batch(predictions=[pred], references=[label]) + if label in top5: + correct_top5 += 1 + total += 1 + + results = accuracy_metric.compute() + top1 = float(results.get("accuracy", 0.0)) + top5 = float(correct_top5 / total) if total else 0.0 + + print(f"Samples: {total}") + print(f"Top-1 accuracy: {top1:.6f}") + print(f"Top-5 accuracy: {top5:.6f}") + if output_offset: + print("Note: model output size is 1001; labels were offset by +1.") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/end_to_end/imagenet/eval_evaluate.py b/end_to_end/imagenet/eval_evaluate.py new file mode 100644 index 00000000..e73d3868 --- /dev/null +++ b/end_to_end/imagenet/eval_evaluate.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# Copyright 2026 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluate a LiteRT vision model on ImageNet-1K validation.""" + +from __future__ import annotations + +import argparse +import os +import sys +from typing import Iterable + +import evaluate +import numpy as np +import transformers +from ai_edge_litert.compiled_model import CompiledModel +from ai_edge_litert.hardware_accelerator import HardwareAccelerator +from datasets import load_dataset +from PIL import Image +from imagenet_preprocessing import ( + infer_input_size, + pick_preprocess_config, + preprocess_image, +) + + +_LITERT_TYPE_TO_NP = { + 1: np.float32, # kLiteRtElementTypeFloat32 + 9: np.int8, # kLiteRtElementTypeInt8 + 3: np.uint8, # kLiteRtElementTypeUInt8 + 2: np.int32, # kLiteRtElementTypeInt32 +} + + +def _default_model_path() -> str: + return os.path.join(os.getcwd(), "mobilenet_v2.tflite") + + +def _pick_output_dtype(requirements: dict) -> np.dtype: + supported = requirements.get("supported_types", []) + for type_id in (1, 9, 3, 2): + if type_id in supported: + return _LITERT_TYPE_TO_NP[type_id] + if supported: + return _LITERT_TYPE_TO_NP.get(supported[0], np.float32) + return np.float32 + + +def _read_output(buffer, requirements: dict) -> np.ndarray: + output_dtype = _pick_output_dtype(requirements) + buffer_size = requirements.get("buffer_size", 0) + itemsize = np.dtype(output_dtype).itemsize + num_elements = buffer_size // itemsize if itemsize else buffer_size + if num_elements == 0: + raise ValueError("Output buffer size is zero") + return buffer.read(num_elements, output_dtype) + + +def _softmax(scores: np.ndarray) -> np.ndarray: + 
scores = scores.astype(np.float32, copy=False) + max_score = np.max(scores) + exp_scores = np.exp(scores - max_score) + return exp_scores / np.sum(exp_scores) + + +class LiteRTImageClassifier: + def __init__(self, model_path: str, labels: list[str], cpu_only: bool = False): + self.task = "image-classification" + accel = ( + HardwareAccelerator.CPU + if cpu_only + else (HardwareAccelerator.GPU | HardwareAccelerator.CPU) + ) + self._model = CompiledModel.from_file(model_path, hardware_accel=accel) + self._signature_index = 0 + self._channels = 3 + self._labels = labels + self._input_height, self._input_width, self._channels_first = infer_input_size( + self._model, self._signature_index + ) + self._preprocess = pick_preprocess_config( + model_path, self._input_height, self._input_width + ) + self._input_buffers = self._model.create_input_buffers(self._signature_index) + self._output_buffers = self._model.create_output_buffers(self._signature_index) + self._output_requirements = self._model.get_output_buffer_requirements( + 0, self._signature_index + ) + layout = "NCHW" if self._channels_first else "NHWC" + print( + "Model input:", + f"{self._input_height}x{self._input_width}", + layout, + f"resize={self._preprocess['resize_size']}", + f"crop={self._preprocess['crop_height']}x{self._preprocess['crop_width']}", + ) + + def __call__(self, inputs: Iterable[Image.Image]): + outputs = [] + for image in inputs: + array = preprocess_image( + image, + self._channels, + self._preprocess["resize_size"], + self._preprocess["crop_height"], + self._preprocess["crop_width"], + self._preprocess["mean"], + self._preprocess["std"], + self._preprocess["resample"], + self._channels_first, + ) + self._input_buffers[0].write(array) + self._model.run_by_index( + self._signature_index, self._input_buffers, self._output_buffers + ) + output_array = _read_output( + self._output_buffers[0], self._output_requirements + ) + probs = _softmax(output_array.reshape(-1)) + topk = 
np.argsort(probs)[-5:][::-1] + outputs.append( + [ + { + "label": self._labels[idx] if self._labels else str(idx), + "score": float(probs[idx]), + } + for idx in topk + ] + ) + return outputs + + +def _parse_args(argv: list[str]): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default=_default_model_path()) + parser.add_argument( + "--max_samples", + type=int, + default=500, + help="Evaluate at most this many samples (0 means full validation set).", + ) + parser.add_argument( + "--cpu_only", + action="store_true", + help="Run evaluation on CPU only.", + ) + return parser.parse_args(argv) + + +def main(argv: list[str]) -> int: + args = _parse_args(argv) + if not os.path.exists(args.model): + raise FileNotFoundError(f"Model not found: {args.model}") + if not hasattr(transformers, "TFPreTrainedModel"): + transformers.TFPreTrainedModel = type("_TFPreTrainedModel", (), {}) + dataset = load_dataset("imagenet-1k", split="validation", token=True) + if args.max_samples: + dataset = dataset.select(range(min(args.max_samples, len(dataset)))) + label_names = None + try: + label_names = dataset.features["label"].names + except Exception: + label_names = None + label_mapping = None + if label_names: + label_mapping = {name: idx for idx, name in enumerate(label_names)} + pipeline = LiteRTImageClassifier(args.model, label_names or [], cpu_only=args.cpu_only) + evaluator = evaluate.evaluator("image-classification") + results = evaluator.compute( + model_or_pipeline=pipeline, + data=dataset, + metric="accuracy", + label_mapping=label_mapping, + ) + print(results) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/end_to_end/imagenet/imagenet_preprocessing.py b/end_to_end/imagenet/imagenet_preprocessing.py new file mode 100644 index 00000000..c4874e34 --- /dev/null +++ b/end_to_end/imagenet/imagenet_preprocessing.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# Copyright 2026 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared ImageNet preprocessing helpers for LiteRT vision eval/classify.""" + +from __future__ import annotations + +import os + +import numpy as np +from PIL import Image + + +def infer_input_size(model, signature_index: int) -> tuple[int, int, bool]: + """Infer model input HxW and layout (channels_first) for preprocessing.""" + default_size = (224, 224) + default_channels_first = True + try: + signature = model.get_signature_by_index(signature_index) + signature_key = signature.get("key") if isinstance(signature, dict) else None + if signature_key: + details = model.get_input_tensor_details(signature_key) + if isinstance(details, dict) and details: + first = next(iter(details.values())) + shape = first.get("shape") if isinstance(first, dict) else None + if shape and len(shape) >= 3: + if len(shape) == 4: + if shape[1] == 3: + return int(shape[2]), int(shape[3]), True + if shape[-1] == 3: + return int(shape[1]), int(shape[2]), False + if len(shape) == 3: + if shape[0] == 3: + return int(shape[1]), int(shape[2]), True + if shape[-1] == 3: + return int(shape[0]), int(shape[1]), False + except Exception: + pass + try: + requirements = model.get_input_buffer_requirements(0, signature_index) + except Exception: + return default_size[0], default_size[1], default_channels_first + dims = ( + requirements.get("dimensions") + or requirements.get("shape") + or requirements.get("dims") + ) + if not dims: + return 
default_size[0], default_size[1], default_channels_first + try: + dims = [int(dim) for dim in dims] + except Exception: + return default_size[0], default_size[1], default_channels_first + + if len(dims) == 4: + if dims[1] == 3: + return dims[2], dims[3], True + if dims[-1] == 3: + return dims[1], dims[2], False + if dims[0] in (1, -1): + dims = dims[1:] + + if len(dims) == 3: + if dims[0] == 3: + return dims[1], dims[2], True + if dims[-1] == 3: + return dims[0], dims[1], False + + return default_size[0], default_size[1], default_channels_first + + +def pick_preprocess_config(model_path: str, input_height: int, input_width: int) -> dict: + def _fit_to_model_input( + resize_size: int, crop_height: int, crop_width: int, resample: int + ) -> dict: + ch = input_height if input_height > 0 else crop_height + cw = input_width if input_width > 0 else crop_width + if ch != crop_height or cw != crop_width: + resize_size = int(round(max(ch, cw) / 0.875)) + crop_height = ch + crop_width = cw + return { + "resize_size": resize_size, + "crop_height": crop_height, + "crop_width": crop_width, + "mean": np.array([0.485, 0.456, 0.406], dtype=np.float32), + "std": np.array([0.229, 0.224, 0.225], dtype=np.float32), + "resample": resample, + } + + model_name = os.path.basename(model_path).lower() + efficientnet_b_cfg = { + "efficientnet_b0": (256, 224, 224, Image.BICUBIC), + "efficientnet_b1": (255, 240, 240, Image.BILINEAR), + "efficientnet_b2": (288, 288, 288, Image.BICUBIC), + "efficientnet_b3": (320, 300, 300, Image.BICUBIC), + "efficientnet_b4": (384, 380, 380, Image.BICUBIC), + "efficientnet_b5": (456, 456, 456, Image.BICUBIC), + "efficientnet_b6": (528, 528, 528, Image.BICUBIC), + "efficientnet_b7": (600, 600, 600, Image.BICUBIC), + } + for arch, (resize, crop_h, crop_w, resample) in efficientnet_b_cfg.items(): + if arch in model_name: + return _fit_to_model_input(resize, crop_h, crop_w, resample) + + if "efficientnet_v2_s" in model_name: + return { + "resize_size": 384, + 
"crop_height": 384, + "crop_width": 384, + "mean": np.array([0.485, 0.456, 0.406], dtype=np.float32), + "std": np.array([0.229, 0.224, 0.225], dtype=np.float32), + "resample": Image.BILINEAR, + } + if "efficientnet_v2_m" in model_name: + return { + "resize_size": 480, + "crop_height": 480, + "crop_width": 480, + "mean": np.array([0.485, 0.456, 0.406], dtype=np.float32), + "std": np.array([0.229, 0.224, 0.225], dtype=np.float32), + "resample": Image.BILINEAR, + } + if "efficientnet_v2_l" in model_name: + return { + "resize_size": 480, + "crop_height": 480, + "crop_width": 480, + "mean": np.array([0.5, 0.5, 0.5], dtype=np.float32), + "std": np.array([0.5, 0.5, 0.5], dtype=np.float32), + "resample": Image.BICUBIC, + } + + crop_height = input_height if input_height > 0 else 224 + crop_width = input_width if input_width > 0 else 224 + resize_size = int(round(max(crop_height, crop_width) / 0.875)) + return { + "resize_size": resize_size, + "crop_height": crop_height, + "crop_width": crop_width, + "mean": np.array([0.485, 0.456, 0.406], dtype=np.float32), + "std": np.array([0.229, 0.224, 0.225], dtype=np.float32), + "resample": Image.BILINEAR, + } + + +def preprocess_image( + image: Image.Image, + channels: int, + resize_size: int, + crop_height: int, + crop_width: int, + mean: np.ndarray, + std: np.ndarray, + resample: int, + channels_first: bool, +) -> np.ndarray: + if channels != 3: + raise ValueError(f"Expected 3 channels, got {channels}") + if resize_size <= 0 or crop_height <= 0 or crop_width <= 0: + raise ValueError( + f"Invalid resize/crop size: resize={resize_size}, crop={crop_height}x{crop_width}" + ) + image = image.convert("RGB") + width, height = image.size + if width < height: + new_width = resize_size + new_height = int(round(height * resize_size / width)) + else: + new_height = resize_size + new_width = int(round(width * resize_size / height)) + image = image.resize((new_width, new_height), resample) + left = int(round((new_width - crop_width) / 2.0)) + 
top = int(round((new_height - crop_height) / 2.0)) + image = image.crop((left, top, left + crop_width, top + crop_height)) + array = np.asarray(image, dtype=np.int32) + array = array.astype(np.float32) / 255.0 + array = (array - mean) / std + if channels_first: + return np.transpose(array, (2, 0, 1)) + return array + + +def load_image( + image_path: str, + channels: int, + resize_size: int, + crop_height: int, + crop_width: int, + mean: np.ndarray, + std: np.ndarray, + resample: int, + channels_first: bool, +) -> np.ndarray: + image = Image.open(image_path).convert("RGB") + return preprocess_image( + image=image, + channels=channels, + resize_size=resize_size, + crop_height=crop_height, + crop_width=crop_width, + mean=mean, + std=std, + resample=resample, + channels_first=channels_first, + ) diff --git a/end_to_end/imagenet/main.py b/end_to_end/imagenet/main.py index 4805a7a8..9f08c39f 100644 --- a/end_to_end/imagenet/main.py +++ b/end_to_end/imagenet/main.py @@ -23,7 +23,7 @@ import sys import numpy as np from ai_edge_litert.compiled_model import CompiledModel -from PIL import Image +from imagenet_preprocessing import infer_input_size, load_image, pick_preprocess_config _LITERT_TYPE_TO_NP = { @@ -42,137 +42,6 @@ def _default_label_path(filename: str) -> str: return os.path.join(os.getcwd(), filename) -# https://docs.pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html -def _load_image( - image_path: str, - channels: int, - resize_size: int, - crop_height: int, - crop_width: int, - mean: np.ndarray, - std: np.ndarray, - resample: int, -) -> np.ndarray: - # Torchvision ImageNet models assume this preprocessing (resize/crop plus - # mean/std normalization from torchvision docs). If your model was converted - # from a different training pipeline, update these constants and transforms or - # you may get poor accuracy. 
- if channels != 3: - raise ValueError(f"Expected 3 channels, got {channels}") - if resize_size <= 0 or crop_height <= 0 or crop_width <= 0: - raise ValueError( - f"Invalid resize/crop size: resize={resize_size}, crop={crop_height}x{crop_width}" - ) - image = Image.open(image_path).convert("RGB") - width, height = image.size - if width < height: - new_width = resize_size - new_height = int(round(height * resize_size / width)) - else: - new_height = resize_size - new_width = int(round(width * resize_size / height)) - image = image.resize((new_width, new_height), resample) - left = int(round((new_width - crop_width) / 2.0)) - top = int(round((new_height - crop_height) / 2.0)) - image = image.crop((left, top, left + crop_width, top + crop_height)) - array = np.asarray(image, dtype=np.int32) - array = array.astype(np.float32) / 255.0 - array = (array - mean) / std - array = np.transpose(array, (2, 0, 1)) - return array - -# Used in the classify path -def _infer_input_size(model, signature_index: int) -> tuple[int, int, bool]: - """Infer the model's input HxW and layout (channels_first) for preprocessing.""" - default_size = (224, 224) - default_channels_first = True - # Try to infer HxW/layout from the compiled model input metadata; fall back. - try: - requirements = model.get_input_buffer_requirements(0, signature_index) - except Exception: - return default_size[0], default_size[1], default_channels_first - dims = ( - requirements.get("dimensions") - or requirements.get("shape") - or requirements.get("dims") - ) - if not dims: - return default_size[0], default_size[1], default_channels_first - try: - dims = [int(dim) for dim in dims] - except Exception: - return default_size[0], default_size[1], default_channels_first - - # Handle common NCHW/NHWC shapes, tolerating unknown batch or dim values. 
- if len(dims) == 4: - if dims[1] == 3: - return dims[2], dims[3], True # NCHW - if dims[-1] == 3: - return dims[1], dims[2], False # NHWC - if dims[0] in (1, -1): - dims = dims[1:] - - if len(dims) == 3: - if dims[0] == 3: - return dims[1], dims[2], True # CHW - if dims[-1] == 3: - return dims[0], dims[1], False # HWC - - return default_size[0], default_size[1], default_channels_first - -def _pick_preprocess_config( - model_path: str, input_height: int, input_width: int -) -> dict: - model_name = os.path.basename(model_path).lower() - if "efficientnet_v2_s" in model_name: - return { - "resize_size": 384, - "crop_height": 384, - "crop_width": 384, - "mean": np.array([0.485, 0.456, 0.406], dtype=np.float32), - "std": np.array([0.229, 0.224, 0.225], dtype=np.float32), - "resample": Image.BILINEAR, - } - if "efficientnet_v2_m" in model_name: - return { - "resize_size": 480, - "crop_height": 480, - "crop_width": 480, - "mean": np.array([0.485, 0.456, 0.406], dtype=np.float32), - "std": np.array([0.229, 0.224, 0.225], dtype=np.float32), - "resample": Image.BILINEAR, - } - if "efficientnet_v2_l" in model_name: - return { - "resize_size": 480, - "crop_height": 480, - "crop_width": 480, - "mean": np.array([0.5, 0.5, 0.5], dtype=np.float32), - "std": np.array([0.5, 0.5, 0.5], dtype=np.float32), - "resample": Image.BICUBIC, - } - if "efficientnet_b" in model_name: - return { - "resize_size": 600, - "crop_height": 600, - "crop_width": 600, - "mean": np.array([0.485, 0.456, 0.406], dtype=np.float32), - "std": np.array([0.229, 0.224, 0.225], dtype=np.float32), - "resample": Image.BICUBIC, - } - crop_height = input_height if input_height > 0 else 224 - crop_width = input_width if input_width > 0 else 224 - resize_size = int(round(max(crop_height, crop_width) / 0.875)) - return { - "resize_size": resize_size, - "crop_height": crop_height, - "crop_width": crop_width, - "mean": np.array([0.485, 0.456, 0.406], dtype=np.float32), - "std": np.array([0.229, 0.224, 0.225], 
dtype=np.float32), - "resample": Image.BILINEAR, - } - - def _pick_output_dtype(requirements: dict) -> np.dtype: supported = requirements.get("supported_types", []) for type_id in (1, 9, 3, 2): @@ -289,12 +158,34 @@ def _parse_convert_args(argv: list[str]): "efficientnet_v2_s", "efficientnet_v2_m", "efficientnet_v2_l", + "alexnet", + "convnext_tiny", + "convnext_small", + "convnext_base", + "convnext_large", + "vgg11", + "vgg11_bn", + "vgg13", + "vgg13_bn", + "vgg16", + "vgg16_bn", + "vgg19", + "vgg19_bn", "mobilenet_v2", + "mobilenet_v3_large", + "mobilenet_v3_small", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", + "squeezenet1_0", + "squeezenet1_1", + "inception_v3", ), default="mobilenet_v2", help="Torchvision model architecture.", @@ -344,71 +235,159 @@ def _init_torchvision_model(arch: str): model_specs = { "efficientnet_b0": ( torchvision.models.efficientnet_b0, - torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_B0_Weights.DEFAULT, ), "efficientnet_b1": ( torchvision.models.efficientnet_b1, - torchvision.models.EfficientNet_B1_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_B1_Weights.DEFAULT, ), "efficientnet_b2": ( torchvision.models.efficientnet_b2, - torchvision.models.EfficientNet_B2_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_B2_Weights.DEFAULT, ), "efficientnet_b3": ( torchvision.models.efficientnet_b3, - torchvision.models.EfficientNet_B3_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_B3_Weights.DEFAULT, ), "efficientnet_b4": ( torchvision.models.efficientnet_b4, - torchvision.models.EfficientNet_B4_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_B4_Weights.DEFAULT, ), "efficientnet_b5": ( torchvision.models.efficientnet_b5, - torchvision.models.EfficientNet_B5_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_B5_Weights.DEFAULT, ), "efficientnet_b6": 
( torchvision.models.efficientnet_b6, - torchvision.models.EfficientNet_B6_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_B6_Weights.DEFAULT, ), "efficientnet_b7": ( torchvision.models.efficientnet_b7, - torchvision.models.EfficientNet_B7_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_B7_Weights.DEFAULT, ), "efficientnet_v2_s": ( torchvision.models.efficientnet_v2_s, - torchvision.models.EfficientNet_V2_S_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_V2_S_Weights.DEFAULT, ), "efficientnet_v2_m": ( torchvision.models.efficientnet_v2_m, - torchvision.models.EfficientNet_V2_M_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_V2_M_Weights.DEFAULT, ), "efficientnet_v2_l": ( torchvision.models.efficientnet_v2_l, - torchvision.models.EfficientNet_V2_L_Weights.IMAGENET1K_V1, + torchvision.models.EfficientNet_V2_L_Weights.DEFAULT, + ), + "alexnet": ( + torchvision.models.alexnet, + torchvision.models.AlexNet_Weights.DEFAULT, + ), + "convnext_tiny": ( + torchvision.models.convnext_tiny, + torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT, + ), + "convnext_small": ( + torchvision.models.convnext_small, + torchvision.models.ConvNeXt_Small_Weights.DEFAULT, + ), + "convnext_base": ( + torchvision.models.convnext_base, + torchvision.models.ConvNeXt_Base_Weights.DEFAULT, + ), + "convnext_large": ( + torchvision.models.convnext_large, + torchvision.models.ConvNeXt_Large_Weights.DEFAULT, + ), + "vgg11": ( + torchvision.models.vgg11, + torchvision.models.VGG11_Weights.DEFAULT, + ), + "vgg11_bn": ( + torchvision.models.vgg11_bn, + torchvision.models.VGG11_BN_Weights.DEFAULT, + ), + "vgg13": ( + torchvision.models.vgg13, + torchvision.models.VGG13_Weights.DEFAULT, + ), + "vgg13_bn": ( + torchvision.models.vgg13_bn, + torchvision.models.VGG13_BN_Weights.DEFAULT, + ), + "vgg16": ( + torchvision.models.vgg16, + torchvision.models.VGG16_Weights.DEFAULT, + ), + "vgg16_bn": ( + torchvision.models.vgg16_bn, + torchvision.models.VGG16_BN_Weights.DEFAULT, + 
), + "vgg19": ( + torchvision.models.vgg19, + torchvision.models.VGG19_Weights.DEFAULT, + ), + "vgg19_bn": ( + torchvision.models.vgg19_bn, + torchvision.models.VGG19_BN_Weights.DEFAULT, ), "mobilenet_v2": ( torchvision.models.mobilenet_v2, - torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1, + torchvision.models.MobileNet_V2_Weights.DEFAULT, + ), + "mobilenet_v3_large": ( + torchvision.models.mobilenet_v3_large, + torchvision.models.MobileNet_V3_Large_Weights.DEFAULT, + ), + "mobilenet_v3_small": ( + torchvision.models.mobilenet_v3_small, + torchvision.models.MobileNet_V3_Small_Weights.DEFAULT, ), "resnet18": ( torchvision.models.resnet18, - torchvision.models.ResNet18_Weights.IMAGENET1K_V1, + torchvision.models.ResNet18_Weights.DEFAULT, ), "resnet34": ( torchvision.models.resnet34, - torchvision.models.ResNet34_Weights.IMAGENET1K_V1, + torchvision.models.ResNet34_Weights.DEFAULT, ), "resnet50": ( torchvision.models.resnet50, - torchvision.models.ResNet50_Weights.IMAGENET1K_V1, + torchvision.models.ResNet50_Weights.DEFAULT, ), "resnet101": ( torchvision.models.resnet101, - torchvision.models.ResNet101_Weights.IMAGENET1K_V1, + torchvision.models.ResNet101_Weights.DEFAULT, ), "resnet152": ( torchvision.models.resnet152, - torchvision.models.ResNet152_Weights.IMAGENET1K_V1, + torchvision.models.ResNet152_Weights.DEFAULT, + ), + "shufflenet_v2_x0_5": ( + torchvision.models.shufflenet_v2_x0_5, + torchvision.models.ShuffleNet_V2_X0_5_Weights.DEFAULT, + ), + "shufflenet_v2_x1_0": ( + torchvision.models.shufflenet_v2_x1_0, + torchvision.models.ShuffleNet_V2_X1_0_Weights.DEFAULT, + ), + "shufflenet_v2_x1_5": ( + torchvision.models.shufflenet_v2_x1_5, + torchvision.models.ShuffleNet_V2_X1_5_Weights.DEFAULT, + ), + "shufflenet_v2_x2_0": ( + torchvision.models.shufflenet_v2_x2_0, + torchvision.models.ShuffleNet_V2_X2_0_Weights.DEFAULT, + ), + "squeezenet1_0": ( + torchvision.models.squeezenet1_0, + torchvision.models.SqueezeNet1_0_Weights.DEFAULT, + ), + "squeezenet1_1": 
( + torchvision.models.squeezenet1_1, + torchvision.models.SqueezeNet1_1_Weights.DEFAULT, + ), + "inception_v3": ( + torchvision.models.inception_v3, + torchvision.models.Inception_V3_Weights.DEFAULT, ), } if arch not in model_specs: @@ -428,6 +407,7 @@ def _convert_pretrained(argv: list[str]) -> int: from ai_edge_quantizer import quantizer, recipe # pylint: disable=import-outside-toplevel model, torch, input_height, input_width = _init_torchvision_model(args.arch) + model.eval() sample_inputs = (torch.randn(1, 3, input_height, input_width),) edge_model = litert_torch.convert(model, sample_inputs) if args.quantize: @@ -447,7 +427,7 @@ def _convert_pretrained(argv: list[str]) -> int: try: os.remove(tmp_path) except FileNotFoundError: - print(f"Warning: failed to delete temp file {tmp_path}: {e}", file=sys.stderr) + print(f"Warning: temp file already deleted: {tmp_path}", file=sys.stderr) else: edge_model.export(args.output) print(f"Saved LiteRT model to: {args.output}") @@ -467,9 +447,17 @@ def _classify(argv: list[str]) -> int: channels = 3 - input_height, input_width, channels_first = _infer_input_size(model, signature_index) - preprocess = _pick_preprocess_config(args.model, input_height, input_width) - image_array = _load_image( + input_height, input_width, channels_first = infer_input_size(model, signature_index) + preprocess = pick_preprocess_config(args.model, input_height, input_width) + layout = "NCHW" if channels_first else "NHWC" + print( + "Model input:", + f"{input_height}x{input_width}", + layout, + f"resize={preprocess['resize_size']}", + f"crop={preprocess['crop_height']}x{preprocess['crop_width']}", + ) + image_array = load_image( args.image, channels, preprocess["resize_size"], @@ -478,9 +466,8 @@ def _classify(argv: list[str]) -> int: preprocess["mean"], preprocess["std"], preprocess["resample"], + channels_first, ) - if not channels_first: - image_array = np.transpose(image_array, (1, 2, 0)) # CHW -> HWC input_buffers = 
model.create_input_buffers(signature_index) output_buffers = model.create_output_buffers(signature_index) input_buffers[0].write(image_array)