diff --git a/setup.py b/setup.py index 8c74e46f9..34a561a94 100644 --- a/setup.py +++ b/setup.py @@ -114,4 +114,51 @@ def find_version(*file_paths): install_requires=requirements, ext_modules=ext_modules, extras_require=extra_requirements -) +)# ------------------------------------------------------------- +# center_net_custom_patch.py +# Add support for custom classes in GluonCV CenterNet +# ------------------------------------------------------------- + +import mxnet as mx +from gluoncv.model_zoo.center_net import get_center_net +from gluoncv.model_zoo.model_store import get_model_file +from gluoncv.model_zoo import model_zoo + + +def center_net_resnet18_v1b_custom(pretrained=False, classes=None, ctx=mx.cpu(), + root='~/.mxnet/models', **kwargs): + """ + Custom CenterNet model that allows user-defined classes. + + Parameters + ---------- + pretrained : bool + Load pretrained COCO weights if True. + classes : list of str + Custom class names. + ctx : mx.context.Context + Device context (CPU/GPU). + root : str + Model storage root directory. + + Returns + ------- + net : gluon.HybridBlock + CenterNet model initialized for custom classes. + """ + # Initialize model with backbone + net = get_center_net('resnet18_v1b', pretrained_base=True, classes=classes, **kwargs) + + if pretrained: + # Load pretrained COCO weights + net.load_parameters(get_model_file('center_net_resnet18_v1b_coco', root=root), ctx=ctx) + # Reset for custom classes if provided + if classes is not None: + net.reset_class(classes) + + return net + + +# Optional: register the model with model_zoo so get_model() can find it +model_zoo._models['center_net_resnet18_v1b_custom'] = center_net_resnet18_v1b_custom +