Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,51 @@ def find_version(*file_paths):
install_requires=requirements,
ext_modules=ext_modules,
extras_require=extra_requirements
)
)# -------------------------------------------------------------
# center_net_custom_patch.py
# Add support for custom classes in GluonCV CenterNet
# -------------------------------------------------------------

import mxnet as mx
from gluoncv.model_zoo.center_net import get_center_net
from gluoncv.model_zoo.model_store import get_model_file
from gluoncv.model_zoo import model_zoo


def center_net_resnet18_v1b_custom(pretrained=False, classes=None, ctx=mx.cpu(),
root='~/.mxnet/models', **kwargs):
"""
Custom CenterNet model that allows user-defined classes.

Parameters
----------
pretrained : bool
Load pretrained COCO weights if True.
classes : list of str
Custom class names.
ctx : mx.context.Context
Device context (CPU/GPU).
root : str
Model storage root directory.

Returns
-------
net : gluon.HybridBlock
CenterNet model initialized for custom classes.
"""
# Initialize model with backbone
net = get_center_net('resnet18_v1b', pretrained_base=True, classes=classes, **kwargs)

if pretrained:
# Load pretrained COCO weights
net.load_parameters(get_model_file('center_net_resnet18_v1b_coco', root=root), ctx=ctx)
# Reset for custom classes if provided
if classes is not None:
net.reset_class(classes)

return net


# Optional: register the model with model_zoo so get_model() can find it
model_zoo._models['center_net_resnet18_v1b_custom'] = center_net_resnet18_v1b_custom