diff --git a/mjx/mujoco/mjx/_src/support.py b/mjx/mujoco/mjx/_src/support.py index 24031821e1..fb96bcf29b 100644 --- a/mjx/mujoco/mjx/_src/support.py +++ b/mjx/mujoco/mjx/_src/support.py @@ -328,7 +328,13 @@ def name2id( class BindModel(object): """Class holding the requested MJX Model and spec id for binding a spec to Model.""" - def __init__(self, model: Model, specs: Sequence[mujoco.MjStruct]): + def __init__( + self, + model: Model, + specs: Sequence[mujoco.MjStruct], + *, + squeeze: bool, + ): self.model = model self.prefix = '' ids = [] @@ -383,7 +389,7 @@ def __init__(self, model: Model, specs: Sequence[mujoco.MjStruct]): else: raise ValueError('invalid spec type') ids.append(spec.id) - if len(ids) == 1: + if len(ids) == 1 and squeeze: self.id = ids[0] else: self.id = ids @@ -405,18 +411,24 @@ def _bind_model( self: Model, obj: mujoco.MjStruct | Iterable[mujoco.MjStruct] ) -> BindModel: """Bind a Mujoco spec to an MJX Model.""" - if isinstance(obj, mujoco.MjStruct): + squeeze = isinstance(obj, mujoco.MjStruct) + if squeeze: obj = (obj,) else: obj = tuple(obj) - return BindModel(self, obj) + return BindModel(self, obj, squeeze=squeeze) class BindData(object): """Class holding the requested MJX Data and spec id for binding a spec to Data.""" def __init__( - self, data: Data, model: Model, specs: Sequence[mujoco.MjStruct] + self, + data: Data, + model: Model, + specs: Sequence[mujoco.MjStruct], + *, + squeeze: bool, ): self.data = data self.model = model @@ -453,7 +465,7 @@ def __init__( else: raise ValueError('invalid spec type') ids.append(spec.id) - if len(ids) == 1: + if len(ids) == 1 and squeeze: self.id = ids[0] else: self.id = ids @@ -571,11 +583,12 @@ def _bind_data( self: Data, model: Model, obj: mujoco.MjStruct | Iterable[mujoco.MjStruct] ) -> BindData: """Bind a Mujoco spec to an MJX Data.""" - if isinstance(obj, mujoco.MjStruct): + squeeze = isinstance(obj, mujoco.MjStruct) + if squeeze: obj = (obj,) else: obj = tuple(obj) - return BindData(self, model, obj) + return BindData(self, model, obj, squeeze=squeeze) Model.bind = _bind_model diff --git a/python/mujoco/__init__.py b/python/mujoco/__init__.py index 799a13a59b..c99eb9d877 100644 --- a/python/mujoco/__init__.py +++ b/python/mujoco/__init__.py @@ -19,7 +19,7 @@ import os import platform import subprocess -from typing import Any, IO, Union, Sequence +from typing import Any, IO, Union, Sequence, List from typing_extensions import TypeAlias import warnings import zipfile @@ -158,6 +158,28 @@ def from_zip(file: Union[str, IO[bytes]]) -> _specs.MjSpec: return _specs.MjSpec.from_string(xml_string, assets=assets) +def _extend_bind_items(items: List[Any], value: Any) -> None: + size = getattr(value, 'size', None) + if size is None: + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + if len(value) == 1: + items.extend(value) + else: + items.append(value) + return + items.append(value) + return + if size == 1: + try: + iter(value) + except TypeError: + items.append(value) + else: + items.extend(value) + else: + items.append(value) + + class _MjBindModel: """Wrapper for MjModel that allows binding multiple specs.""" @@ -167,7 +189,7 @@ def __init__(self, elements: Sequence[Any]): def __getattr__(self, key: str): items = [] for e in self.elements: - items.extend(getattr(e, key)) + _extend_bind_items(items, getattr(e, key)) return items def __setattr__(self, key: str, value: Any): @@ -183,7 +205,7 @@ def __init__(self, elements: Sequence[Any]): def __getattr__(self, key: str): items = [] for e in self.elements: - items.extend(getattr(e, key)) + _extend_bind_items(items, getattr(e, key)) return items def __setattr__(self, key: str, value: Any): diff --git a/python/mujoco/specs_test.py b/python/mujoco/specs_test.py index 8b472a3cd3..c3deb8bca8 100644 --- a/python/mujoco/specs_test.py +++ b/python/mujoco/specs_test.py @@ -1423,6 +1423,12 @@ def test_bind(self): np.testing.assert_array_equal(mj_model.bind(joint_box).qposadr, 7) np.testing.assert_array_equal(mj_data.bind(joints).qpos, [0, 0]) np.testing.assert_array_equal(mj_model.bind(joints).qposadr, [7, 8]) + np.testing.assert_array_equal( + mj_model.bind(spec.geoms[0:1]).size, mj_model.geom_size[0:1] + ) + np.testing.assert_array_equal( + mj_data.bind(spec.geoms[0:1]).xpos, mj_data.geom_xpos[0:1] + ) np.testing.assert_array_equal(mj_data.bind([]).qpos, []) np.testing.assert_array_equal(mj_model.bind([]).qposadr, []) mj_data.bind(joints).qpos = np.array([1, 2])