diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f2c0680571..1a207fd95b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -347,7 +347,7 @@ jobs: CHATMSG_AUTHOR_EMAIL: ${{ github.event.head_commit.author.email }} CHATMSG_COMMIT_MESSAGE: ${{ github.event.head_commit.message }} CHATMSG_JOB_ID: ${{ matrix.label }} - if: ${{ failure() && github.event_name == 'push' && env.GCHAT_API_URL != '' }} + if: failure() && github.ref_name == 'main' && github.event_name == 'push' && env.GCHAT_API_URL != '' run: bash ./.github/workflows/build_steps.sh notify_team_chat # This job quickly determines if MuJoCo Studio is broken. @@ -406,7 +406,7 @@ jobs: CHATMSG_AUTHOR_EMAIL: ${{ github.event.head_commit.author.email }} CHATMSG_COMMIT_MESSAGE: ${{ github.event.head_commit.message }} CHATMSG_JOB_ID: ${{ matrix.label }} - if: ${{ failure() && github.event_name == 'push' && env.GCHAT_API_URL != '' }} + if: failure() && github.ref_name == 'main' && github.event_name == 'push' && env.GCHAT_API_URL != '' run: bash ./.github/workflows/build_steps.sh notify_team_chat @@ -451,5 +451,5 @@ jobs: CHATMSG_AUTHOR_EMAIL: ${{ github.event.head_commit.author.email }} CHATMSG_COMMIT_MESSAGE: ${{ github.event.head_commit.message }} CHATMSG_JOB_ID: ${{ env.label }} - if: ${{ failure() && github.event_name == 'push' && env.GCHAT_API_URL != '' }} + if: failure() && github.ref_name == 'main' && github.event_name == 'push' && env.GCHAT_API_URL != '' run: bash ./.github/workflows/build_steps.sh notify_team_chat diff --git a/.github/workflows/build_steps.sh b/.github/workflows/build_steps.sh index 7286aca673..4c55e5d431 100755 --- a/.github/workflows/build_steps.sh +++ b/.github/workflows/build_steps.sh @@ -62,6 +62,12 @@ setup_emsdk() { git clone https://github.com/emscripten-core/emsdk.git ./emsdk/emsdk install 4.0.10 ./emsdk/emsdk activate 4.0.10 + # Force installing emscripten's typescript dependencies. This is a + # workaround for the github update to a newer typescript, which gives an + # error on the deprecated `--outFile` flag. + pushd emsdk/upstream/emscripten + npm i + popd } @@ -101,8 +107,6 @@ copy_plugins_posix() { mkdir -p ${TMPDIR}/mujoco_install/mujoco_plugin && cp lib/libactuator.* ${TMPDIR}/mujoco_install/mujoco_plugin && cp lib/libelasticity.* ${TMPDIR}/mujoco_install/mujoco_plugin && - cp lib/libobj_decoder.* ${TMPDIR}/mujoco_install/mujoco_plugin && - cp lib/libstl_decoder.* ${TMPDIR}/mujoco_install/mujoco_plugin && cp lib/libsensor.* ${TMPDIR}/mujoco_install/mujoco_plugin && cp lib/libsdf_plugin.* ${TMPDIR}/mujoco_install/mujoco_plugin } @@ -113,8 +117,6 @@ copy_plugins_window() { mkdir -p ${TMPDIR}/mujoco_install/mujoco_plugin && cp bin/Release/actuator.dll ${TMPDIR}/mujoco_install/mujoco_plugin && cp bin/Release/elasticity.dll ${TMPDIR}/mujoco_install/mujoco_plugin && - cp bin/Release/obj_decoder.dll ${TMPDIR}/mujoco_install/mujoco_plugin && - cp bin/Release/stl_decoder.dll ${TMPDIR}/mujoco_install/mujoco_plugin && cp bin/Release/sensor.dll ${TMPDIR}/mujoco_install/mujoco_plugin } diff --git a/cmake/ShellTests.cmake b/cmake/ShellTests.cmake index d2480a25d1..b26b165850 100644 --- a/cmake/ShellTests.cmake +++ b/cmake/ShellTests.cmake @@ -38,7 +38,6 @@ function(add_mujoco_shell_test TEST_NAME TARGET_BINARY) "CMAKE_SOURCE_DIR=${CMAKE_SOURCE_DIR}" "TARGET_BINARY=$" "TEST_TMPDIR=${TEST_TMPDIR}" - "MUJOCO_PLUGIN_DIR=$" ) if(WIN32) # Define the directory containing the mujoco DLL library so that it can be added to the PATH. diff --git a/doc/APIreference/functions.rst b/doc/APIreference/functions.rst index df6381f541..fa61555b04 100644 --- a/doc/APIreference/functions.rst +++ b/doc/APIreference/functions.rst @@ -4658,6 +4658,15 @@ Set actuator to muscle; return error if any.a Set actuator to active adhesion; return error if any. +.. _mjs_setToDCMotor: + +`mjs_setToDCMotor <#mjs_setToDCMotor>`__ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. mujoco-include:: mjs_setToDCMotor + +Set actuator to DC motor; return error if any. + .. _AddAssets: Assets diff --git a/doc/XMLreference.rst b/doc/XMLreference.rst index 7e90fa1488..2552f93219 100644 --- a/doc/XMLreference.rst +++ b/doc/XMLreference.rst @@ -6323,6 +6323,174 @@ This element has a subset of the common attributes and two custom attributes. to the target body. +.. _actuator-dcmotor: + +:el-prefix:`actuator/` |-| **dcmotor** |*| +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +This element creates a DC motor actuator. Note that :el:`dcmotor` is quite different from the :ref:`general actuation +model`. Unlike the general model where the components of force generation are independent affine functions +mapping from control to force, :el:`dcmotor` relies on highly coupled physical dynamics. See the `DC motor technical +note <_static/dcmotor.pdf>`__ for complete mathematical formulations and parameter semantics, but we include a few +important notes here: + +- Note that while :ref:`resistance`, :ref:`motorconst` and + :ref:`nominal` are each optional, some combination of them is required. + See Section 2.1 of the `technical note <_static/dcmotor.pdf>`__. +- The control :ref:`input` semantic is either the voltage applied to the motor terminals, or a + position or velocity target for a PID :ref:`controller`. +- Optional features include electrical dynamics (:ref:`inductance`), + :ref:`cogging torque`, :ref:`thermal resistance variation`, and + :ref:`LuGre` friction. + +The underlying :el:`general` attributes are set to the :el:`dcmotor` type, and their associated parameter arrays are +computed internally: + +========= ======= ========= ======== +Attribute Setting Attribute Setting +========= ======= ========= ======== +dyntype dcmotor dynprm computed +gaintype dcmotor gainprm computed +biastype dcmotor biasprm computed +========= ======= ========= ======== + +This element has the following custom attributes in addition to the common attributes: + +.. _actuator-dcmotor-name: + +.. _actuator-dcmotor-class: + +.. _actuator-dcmotor-group: + +.. _actuator-dcmotor-delay: + +.. _actuator-dcmotor-nsample: + +.. _actuator-dcmotor-interp: + +.. _actuator-dcmotor-ctrllimited: + +.. _actuator-dcmotor-ctrlrange: + +.. _actuator-dcmotor-lengthrange: + +.. _actuator-dcmotor-gear: + +.. _actuator-dcmotor-damping: + +.. _actuator-dcmotor-armature: + +.. _actuator-dcmotor-cranklength: + +.. _actuator-dcmotor-joint: + +.. _actuator-dcmotor-jointinparent: + +.. _actuator-dcmotor-tendon: + +.. _actuator-dcmotor-cranksite: + +.. _actuator-dcmotor-slidersite: + +.. _actuator-dcmotor-site: + +.. _actuator-dcmotor-refsite: + +.. _actuator-dcmotor-user: + +.. |actuator/dcmotor attrib list| replace:: + :at:`name`, :at:`class`, :at:`group`, :at:`nsample`, :at:`interp`, :at:`delay`, :at:`ctrllimited`, :at:`ctrlrange`, + :at:`lengthrange`, :at:`gear`, :at:`damping`, :at:`armature`, :at:`cranklength`, :at:`joint`, :at:`jointinparent`, + :at:`tendon`, :at:`cranksite`, :at:`slidersite`, :at:`site`, :at:`refsite`, :at:`user` + +|actuator/dcmotor attrib list| + Same as in actuator/ :ref:`general `. + +.. _actuator-dcmotor-resistance: + +:at:`resistance`: :at-val:`real, optional` + Terminal resistance :math:`R` in Ohm. (see `tech note <_static/dcmotor.pdf>`__ for details) + +.. _actuator-dcmotor-motorconst: + +:at:`motorconst`: :at-val:`real(2), optional` + Motor constants, defined as :at:`motorconst` = ":at-val:`Kt` :at-val:`Ke`" (N·m/A, equivalently V·s/rad). + :at-val:`Kt` is the torque constant and :at-val:`Ke` the back-EMF constant; they can differ when magnetic saturation + is present. If both are positive, the effective constant is :math:`K = \sqrt{K_t K_e}` (geometric mean). If only one + is positive, :math:`K` equals that value; a single value is interpreted as :math:`K_t = K_e`. If your datasheet gives + the speed constant :math:`K_v` in rad/(V·s), use :math:`K_e = 1/K_v`. (see `tech note <_static/dcmotor.pdf>`__ for + details) + +.. _actuator-dcmotor-nominal: + +:at:`nominal`: :at-val:`real(3), optional` + Nominal operating point, defined as :at:`nominal` = ":at-val:`voltage` :at-val:`stall_torque` + :at-val:`no_load_speed`". The compiler derives :math:`K =` :at-val:`voltage` / :at-val:`no_load_speed` and :math:`R = + K` · :at-val:`voltage` / :at-val:`stall_torque`. (see `tech note <_static/dcmotor.pdf>`__ for details) + +.. _actuator-dcmotor-inductance: + +:at:`inductance`: :at-val:`real(2), "0 0"` + Electrical dynamics, defined as :at:`inductance` = ":at-val:`L` :at-val:`timeconst`" (Henry, seconds). These are + alternative specifications: :at-val:`L` is the winding inductance and :at-val:`timeconst` :math:`= L/R` is the + electrical time constant. Specify one; if both are given, :at-val:`L` takes precedence. If both are 0 (the default), + no electrical dynamics are modeled and the current is computed algebraically. Adds one activation variable for + armature current. (see `tech note <_static/dcmotor.pdf>`__ for details) + +.. _actuator-dcmotor-thermal: + +:at:`thermal`: :at-val:`real(6), "0 0 0 0 0 0"` + Thermal model, defined as :at:`thermal` = ":at-val:`resistance` :at-val:`capacitance` :at-val:`timeconst` + :at-val:`tempcoef` :at-val:`reftemp` :at-val:`ambient`" (K/W, J/K, s, 1/K, °C, °C). The first three sub-values + specify the thermal time constant: :at-val:`timeconst` = :at-val:`resistance` :math:`\times` :at-val:`capacitance`. + Specify either :at-val:`timeconst` directly, or :at-val:`resistance` and :at-val:`capacitance`; if all three are + given, :at-val:`timeconst` takes precedence. If all are 0 (the default), thermal modeling is disabled. Adds one + activation variable for winding temperature. (see `tech note <_static/dcmotor.pdf>`__ for details) + +.. _actuator-dcmotor-saturation: + +:at:`saturation`: :at-val:`real(4), "0 0 0 0"` + Limits on the actuator, defined as :at:`saturation` = ":at-val:`torque` :at-val:`current` :at-val:`voltage` + :at-val:`current_rate`". :at-val:`torque` and :at-val:`current` are alternative specifications of the maximum + continuous torque: if :at-val:`current` is given, :at-val:`torque` :math:`= K \cdot` :at-val:`current`; if both are + given, :at-val:`torque` takes precedence. Sets :at:`forcerange` to [:math:`-\tau_{\max},\, \tau_{\max}`]. + :at-val:`voltage` sets the maximum voltage :math:`V_{\max}`. :at-val:`current_rate` sets the maximum rate of change + of current :math:`(di/dt)_{\max}` (requires :ref:`inductance`). A value of 0 (the + default) for any sub-value disables the respective limit. (see `tech note <_static/dcmotor.pdf>`__ for details) + +.. _actuator-dcmotor-cogging: + +:at:`cogging`: :at-val:`real(3), "0 0 0"` + Cogging torque, defined as :at:`cogging` = ":at-val:`amplitude` :at-val:`poles` :at-val:`phase`" (N·m, integer, rad). + Adds a position-dependent torque :math:`= \textsf{amplitude} \cdot \sin(\textsf{poles} \cdot \theta + + \textsf{phase})`. Disabled when :at-val:`amplitude` = 0 (the default). + (see `tech note <_static/dcmotor.pdf>`__ for details) + +.. _actuator-dcmotor-lugre: + +:at:`lugre`: :at-val:`real(6), "0 0 0 0 0 0"` + LuGre friction, defined as :at:`lugre` = ":at-val:`stiffness` :at-val:`damping` :at-val:`viscous` :at-val:`coulomb` + :at-val:`static` :at-val:`stribeck`" (N·m/rad, N·m·s/rad, N·m·s/rad, N·m, N·m, rad/s). Disabled when + :at-val:`stiffness` = 0 (the default). Adds one activation variable for bristle deflection. Note that the + :at-val:`viscous` coefficient is mapped directly to the actuator :ref:`damping` array + (specifically the linear term, :at-val:`damping[0]`). If both are specified, their values are summed. + (see `tech note <_static/dcmotor.pdf>`__ for details) + +.. _actuator-dcmotor-input: + +:at:`input`: :at-val:`[voltage, position, velocity], "voltage"` + Specifies the input signal semantics. In "voltage" mode, the control directly sets applied motor voltage. In + "position" or "velocity" modes, the PID :ref:`controller` uses the control as a + reference setpoint relative to the joint trajectory. (see `tech note <_static/dcmotor.pdf>`__ for details) + +.. _actuator-dcmotor-controller: + +:at:`controller`: :at-val:`real(5), "0 0 0 0 0"` + PID controller parameters, defined as :at:`controller` = ":at-val:`kp` :at-val:`ki` :at-val:`kd` + :at-val:`slewmax` :at-val:`Imax`". Depending on the :at:`input` mode, the controller stabilizes either position or + velocity. If the :at:`input` mode is voltage, the controller is ignored. A value of 0 (the default) disables the + respective feature: :at-val:`slewmax` = 0 means no slew-rate limiting, :at-val:`Imax` = 0 means no anti-windup + clamping. (see `tech note <_static/dcmotor.pdf>`__ for details) + .. _actuator-plugin: :el-prefix:`actuator/` |-| **plugin** |?| @@ -9887,6 +10055,57 @@ refsite, tendon, slidersite, cranksite. All :ref:`adhesion ` attributes are available here except: name, class, body. +.. _default-dcmotor: + +.. _default-dcmotor-ctrllimited: + +.. _default-dcmotor-ctrlrange: + +.. _default-dcmotor-gear: + +.. _default-dcmotor-damping: + +.. _default-dcmotor-armature: + +.. _default-dcmotor-cranklength: + +.. _default-dcmotor-user: + +.. _default-dcmotor-group: + +.. _default-dcmotor-delay: + +.. _default-dcmotor-nsample: + +.. _default-dcmotor-interp: + +.. _default-dcmotor-motorconst: + +.. _default-dcmotor-resistance: + +.. _default-dcmotor-nominal: + +.. _default-dcmotor-saturation: + +.. _default-dcmotor-inductance: + +.. _default-dcmotor-cogging: + +.. _default-dcmotor-controller: + +.. _default-dcmotor-input: + +.. _default-dcmotor-thermal: + +.. _default-dcmotor-lugre: + +:el-prefix:`default/` |-| **dcmotor** |?| +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +All :ref:`dcmotor ` attributes are available here except: name, class, joint, jointinparent, site, +refsite, tendon, slidersite, cranksite. + + .. _custom: **custom** |*| diff --git a/doc/XMLschema.rst b/doc/XMLschema.rst index 74f725618c..4dbc6c2991 100755 --- a/doc/XMLschema.rst +++ b/doc/XMLschema.rst @@ -2984,6 +2984,105 @@ :ref:`gain` + .. dropdown:: :ref:`dcmotor` |*| + + .. grid:: 2 3 4 4 + :gutter: 0 + + .. grid-item:: + :ref:`name` + + .. grid-item:: + :ref:`class` + + .. grid-item:: + :ref:`group` + + .. grid-item:: + :ref:`nsample` + + .. grid-item:: + :ref:`interp` + + .. grid-item:: + :ref:`delay` + + .. grid-item:: + :ref:`ctrllimited` + + .. grid-item:: + :ref:`ctrlrange` + + .. grid-item:: + :ref:`lengthrange` + + .. grid-item:: + :ref:`gear` + + .. grid-item:: + :ref:`damping` + + .. grid-item:: + :ref:`armature` + + .. grid-item:: + :ref:`cranklength` + + .. grid-item:: + :ref:`user` + + .. grid-item:: + :ref:`joint` + + .. grid-item:: + :ref:`jointinparent` + + .. grid-item:: + :ref:`tendon` + + .. grid-item:: + :ref:`slidersite` + + .. grid-item:: + :ref:`cranksite` + + .. grid-item:: + :ref:`site` + + .. grid-item:: + :ref:`refsite` + + .. grid-item:: + :ref:`motorconst` + + .. grid-item:: + :ref:`resistance` + + .. grid-item:: + :ref:`nominal` + + .. grid-item:: + :ref:`saturation` + + .. grid-item:: + :ref:`inductance` + + .. grid-item:: + :ref:`cogging` + + .. grid-item:: + :ref:`controller` + + .. grid-item:: + :ref:`thermal` + + .. grid-item:: + :ref:`lugre` + + .. grid-item:: + :ref:`input` + + .. dropdown:: :ref:`plugin` |*| .. grid:: 2 3 4 4 @@ -6146,6 +6245,75 @@ :ref:`delay` + .. dropdown:: :ref:`dcmotor` :octicon:`dot` + + .. grid:: 2 3 4 4 + :gutter: 0 + + .. grid-item:: + :ref:`ctrllimited` + + .. grid-item:: + :ref:`ctrlrange` + + .. grid-item:: + :ref:`gear` + + .. grid-item:: + :ref:`damping` + + .. grid-item:: + :ref:`armature` + + .. grid-item:: + :ref:`cranklength` + + .. grid-item:: + :ref:`user` + + .. grid-item:: + :ref:`group` + + .. grid-item:: + :ref:`nsample` + + .. grid-item:: + :ref:`interp` + + .. grid-item:: + :ref:`delay` + + .. grid-item:: + :ref:`motorconst` + + .. grid-item:: + :ref:`resistance` + + .. grid-item:: + :ref:`nominal` + + .. grid-item:: + :ref:`saturation` + + .. grid-item:: + :ref:`inductance` + + .. grid-item:: + :ref:`cogging` + + .. grid-item:: + :ref:`controller` + + .. grid-item:: + :ref:`input` + + .. grid-item:: + :ref:`thermal` + + .. grid-item:: + :ref:`lugre` + + .. dropdown:: :ref:`custom` |*| diff --git a/doc/_static/dcmotor.pdf b/doc/_static/dcmotor.pdf new file mode 100644 index 0000000000..caf8a36c1c Binary files /dev/null and b/doc/_static/dcmotor.pdf differ diff --git a/doc/changelog.rst b/doc/changelog.rst index 3931613049..1bd56f6752 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -8,6 +8,9 @@ Upcoming version (not yet released) General ^^^^^^^ +- Added the :ref:`dcmotor` actuator for modeling DC motors. Supports optional + electrical dynamics (inductance), cogging torque, thermal resistance variation, and LuGre friction. See the + `technical note <_static/dcmotor.pdf>`__ for more details. - Actuators with joint or tendon transmissions can now contribute :ref:`damping` and :ref:`armature` to their transmission target. These are applied during the passive force and inertia computations, respectively, and are scaled by gear\ :sup:`2` diff --git a/doc/dcmotor/buildpdf.sh b/doc/dcmotor/buildpdf.sh new file mode 100755 index 0000000000..ed73a053ab --- /dev/null +++ b/doc/dcmotor/buildpdf.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright 2026 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pdflatex -interaction=nonstopmode -jobname=dcmotor dcmotor.tex 2>&1 | grep -E '(Error|Output written)' && \ +bibtex dcmotor 2>&1 | grep -v '^$' && \ +pdflatex -interaction=nonstopmode -jobname=dcmotor dcmotor.tex 2>&1 | grep -E '(Error|Output written)' && \ +pdflatex -interaction=nonstopmode -jobname=dcmotor dcmotor.tex 2>&1 | grep -E '(Error|Output written)' && \ +rm -f *.{aux,log,out,bbl,blg} && \ +mv dcmotor.pdf ../_static/ diff --git a/doc/dcmotor/dcmotor.tex b/doc/dcmotor/dcmotor.tex new file mode 100644 index 0000000000..c60e08b433 --- /dev/null +++ b/doc/dcmotor/dcmotor.tex @@ -0,0 +1,1416 @@ +% Copyright 2026 DeepMind Technologies Limited +% +% Licensed under the Apache License, Version 2.0 (the "License"); +% you may not use this file except in compliance with the License. +% You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +% See the License for the specific language governing permissions and +% limitations under the License. + +\documentclass[10pt, a4paper, twocolumn]{article} +\usepackage[utf8]{inputenc} +\usepackage[T1]{fontenc} +\usepackage{roboto-mono} +\usepackage{relsize} +\let\oldtexttt\texttt +\renewcommand{\texttt}[1]{{\smaller\oldtexttt{#1}}} +\usepackage{amsmath, amssymb} +\usepackage{multicol} +\usepackage{geometry} +\geometry{margin=0.75in} +\usepackage{titlesec} +\titlespacing*{\section}{0pt}{1.5ex plus 0.5ex minus 0.2ex}{1ex plus 0.2ex} +\titlespacing*{\subsection}{0pt}{1.2ex plus 0.4ex minus 0.2ex}{0.8ex plus 0.2ex} +\setlength{\parskip}{0.4ex plus 0.1ex minus 0.1ex} +\usepackage{booktabs} +\usepackage{enumitem} +\setlist[itemize]{label=\scalebox{0.8}{$\bullet$}} +\usepackage{float} +\usepackage{titling} +\setlength{\droptitle}{-4em} +\usepackage{stfloats} +\usepackage{url} +\renewcommand{\UrlFont}{\small\ttfamily} +\usepackage{tikz} +\usepackage{pgfplots} +\pgfplotsset{compat=1.18} +\usepgfplotslibrary{fillbetween} +\usepackage{hyperref} +\hypersetup{colorlinks=true, linkcolor=blue, urlcolor=blue, citecolor=blue} +\usepackage{caption} +\usepackage{subcaption} +\captionsetup{font=footnotesize, labelfont=footnotesize} +\usepackage{xcolor} +\usepackage{listings} +\lstset{ + language=C, + basicstyle=\footnotesize\ttfamily, + keywordstyle=\bfseries\color{blue!70!black}, + commentstyle=\itshape\color{gray}, + stringstyle=\color{red!60!black}, + numbers=left, + numberstyle=\tiny\color{gray}, + numbersep=5pt, + frame=single, + framerule=0.4pt, + rulecolor=\color{gray!40}, + backgroundcolor=\color{gray!5}, + breaklines=true, + columns=fullflexible, + keepspaces=true, + showstringspaces=false, + tabsize=2, + xleftmargin=1.5em, + framexleftmargin=1.5em, + aboveskip=0.8em, + belowskip=0.5em, + morekeywords={mjtNum, mjModel, mjData, mjtByte}, +} + +\newcommand{\atR}{\texttt{resistance}} +\newcommand{\atK}{\texttt{motorconst}} +\newcommand{\atKt}{\texttt{motorconst:Kt}} +\newcommand{\atKe}{\texttt{motorconst:Ke}} +\newcommand{\atVM}{\texttt{nominal:voltage}} +\newcommand{\atSTALL}{\texttt{nominal:stall\_torque}} +\newcommand{\atNLS}{\texttt{nominal:no\_load\_speed}} +\newcommand{\atTMAX}{\texttt{saturation:torque}} +\newcommand{\atIMAX}{\texttt{saturation:current}} +\newcommand{\atVMAX}{\texttt{saturation:voltage}} +\newcommand{\atCRATE}{\texttt{saturation:current\_rate}} +\newcommand{\atKP}{\texttt{controller:kp}} +\newcommand{\atKI}{\texttt{controller:ki}} +\newcommand{\atKD}{\texttt{controller:kd}} +\newcommand{\atSLEW}{\texttt{controller:slewmax}} +\newcommand{\atIMAXINT}{\texttt{controller:Imax}} +\newcommand{\atL}{\texttt{inductance:L}} +\newcommand{\atTE}{\texttt{inductance:timeconst}} +\newcommand{\atCOGA}{\texttt{cogging:amplitude}} +\newcommand{\atCOGP}{\texttt{cogging:poles}} +\newcommand{\atCOGPH}{\texttt{cogging:phase}} +\newcommand{\atRT}{\texttt{thermal:resistance}} +\newcommand{\atTC}{\texttt{thermal:capacitance}} +\newcommand{\atTT}{\texttt{thermal:timeconst}} +\newcommand{\atALPHA}{\texttt{thermal:tempcoef}} +\newcommand{\atTREF}{\texttt{thermal:reftemp}} +\newcommand{\atTAMB}{\texttt{thermal:ambient}} +\newcommand{\atSIG}{\texttt{lugre:stiffness}} +\newcommand{\atSIGD}{\texttt{lugre:damping}} +\newcommand{\atTAUC}{\texttt{lugre:coulomb}} +\newcommand{\atTAUS}{\texttt{lugre:static}} +\newcommand{\atWS}{\texttt{lugre:stribeck}} +\newcommand{\atSIGV}{\texttt{lugre:viscous}} + +\title{MuJoCo DC Motor Model} +\author{Google DeepMind} +\date{} + +\begin{document} + +\maketitle + +\noindent We review DC motors and describe MuJoCo's \texttt{dcmotor} actuator. The equations are derived for brushed motors but apply equally to brushless ones, where electronic commutation reduces to an equivalent circuit. + +%============================================================================= +% BACKGROUND +%============================================================================= +\section{Background} +\label{sec:background} + +We use SI units throughout, but any coherent system of units applies. We assume motion is rotational; for linear motion replace radians with meters as required. + +% --------------------------------------------------------------------------- +% Electromagnetic Model +% --------------------------------------------------------------------------- +\subsection{Electromagnetic Model} +\label{sec:electromagnetics} + +The key electro-mechanical variables are + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Symbol & Description & Units \\ +\midrule +$v$ & Applied voltage & Volt \\ +$i$ & Current & Ampere \\ +$\omega$ & Angular velocity & radian/second \\ +$\tau$ & Output torque & Newton $\cdot$ meter \\ +\bottomrule +\end{tabular} +\end{table} + +\noindent and the key constants are + +\begin{table}[H] +\centering +\footnotesize +\begin{tabular}{@{}lll@{}} +\toprule +Symbol & Description & Units \\ +\midrule +$R$ & Resistance & Ohm \\ +$K_t$ & Torque constant & Newton $\cdot$ meter/Ampere \\ +$K_e$ & Back-EMF constant & Volt $\cdot$ second/radian \\ +\bottomrule +\end{tabular} +\end{table} + +\noindent The quasi-static model~\cite{hughes2019, maxon_formulas, simscape_dcmotor} assumes instantaneous electrical dynamics: current and torque are direct functions of voltage and velocity. The constitutive equations are the voltage balance \eqref{eq:voltage} and the torque law \eqref{eq:torque_law}: +\begin{subequations} +\label{eq:motor_laws} +\begin{align} + v &= i \, R + K_e \, \omega \label{eq:voltage} \\ + \tau &= K_t i \label{eq:torque_law} +\end{align} +\end{subequations} + +\noindent Solving for current and substituting, we have + +\begin{equation} + \tau = \frac{K_t}{R} (v - K_e \, \omega) + \label{eq:torque_speed} +\end{equation} + +\noindent Output torque is proportional to the difference between applied and back-EMF voltage $v_{\text{back}} = K_e \, \omega$ (Figure~\ref{fig:torque_speed}). + +\paragraph{Electrical constants.} +Fundamentally, both $K_t$ and $K_e$ arise from the same physical quantity: the magnetic flux $\Phi$ of the coil. Faraday's law gives $v_{\text{back}} = \Phi \, \omega$ and the Lorentz force gives $\tau = \Phi \, i$, so in SI units: + +\begin{equation*} + K_e = K_t + \label{eq:ke_eq_kt} +\end{equation*} + +\noindent This can also be seen from energy conservation: $P_e = i \, (K_e \, \omega) = (K_t \, i) \, \omega = P_m$. + +\pagebreak +\noindent Note the dimensions match: + +\vspace*{-\abovedisplayskip} +\begin{equation*} + \frac{\text{Volt}}{\text{radian}/\text{second}} + = \frac{\text{Joule}}{\text{Coulomb}/\text{second}} + = \frac{\text{Newton} \cdot \text{meter}}{\text{Ampere}} +\end{equation*} + +\noindent If these constants are the same, why have both? Two reasons. First, datasheets typically use mixed units ($K_e$ in RPM/V, $K_t$ in mN$\cdot$m/A), giving different values for the same physical quantity. Second, $K_t$ and $K_e$ are measured differently: $K_t$ by locking the rotor and measuring torque per Ampere; $K_e$ by spinning the rotor and measuring open-circuit Volts per radian/second. In the first case, high currents can lead to magnetic field saturation in the core, causing the effective $K_t$ to drop below $K_e$. The equality $K_e = K_t$ thus assumes $\Phi$ independent of $i$. We make this assumption for now and use a single motor constant $K \equiv K_t = K_e$ throughout the remainder of this document and internally in MuJoCo, but see note at end of \S\ref{sec:dcmotor}. + +\begin{figure}[H] +\centering +\begin{tikzpicture} +\pgfmathsetmacro{\taus}{1.0} +\pgfmathsetmacro{\wz}{1.0} +\begin{axis}[ + width=0.9\columnwidth, height=0.55\columnwidth, + axis lines=left, + clip=false, + xlabel={$\omega$}, ylabel={$\tau$}, + xmin=0, xmax={\wz*1.15}, ymin=0, ymax={\taus*1.15}, + xtick={\wz}, xticklabels={$\omega_0$}, + ytick={\taus}, yticklabels={$\tau_0$}, + tick style={thick}, + every axis x label/.style={at={(ticklabel* cs:1)}, anchor=west}, + every axis y label/.style={at={(ticklabel* cs:1)}, anchor=south}, +] +\addplot[thick, blue!15] coordinates {(0,\taus*0.7) (\wz*0.7,0)}; +\addplot[thick, blue!25] coordinates {(0,\taus*0.8) (\wz*0.8,0)}; +\addplot[thick, blue!50] coordinates {(0,\taus*0.9) (\wz*0.9,0)}; +\addplot[thick, blue] coordinates {(0,\taus) (\wz,0)}; +\node[font=\scriptsize, text=gray, align=center] + at (axis cs: \wz*0.25, \taus*0.2) + {decreasing\\ voltage}; +\draw[->, thick, gray] (axis cs: \wz*0.4, \taus*0.55) + -- (axis cs: \wz*0.4, \taus*0.15); +\node[font=\scriptsize] at (axis cs: \wz*0.4, -0.08) + {Speed}; +\node[font=\scriptsize, rotate=90] at (axis cs: -0.04, \taus*0.4) + {Torque}; +\end{axis} +\end{tikzpicture} +\caption{Torque-speed relationship \eqref{eq:torque_speed} at fixed voltage. As voltage decreases, the maximum torque and speed decrease linearly.} +\label{fig:torque_speed} +\end{figure} + +\paragraph{Current Saturation.} +A maximum current rating $i_{\max}$ limits the output torque: +\begin{equation} + \tau = \text{clip}\!\left(\frac{K}{R}(v - K \, \omega),\; + \pm K \, i_{\max} \right) + \label{eq:saturation} +\end{equation} +where the maximum torque $\tau_{\max} = K \, i_{\max}$. The feasible torque-speed envelope forms a parallelogram: + +\begin{figure}[H] +\centering +\begin{tikzpicture} +\pgfmathsetmacro{\taus}{1.3} +\pgfmathsetmacro{\wz}{1.0} +\pgfmathsetmacro{\taumax}{0.7} +\pgfmathsetmacro{\slope}{\taus/\wz} +\pgfmathsetmacro{\wcu}{(\taus-\taumax)/\slope} +\pgfmathsetmacro{\wcl}{(\taus+\taumax)/\slope} +\pgfmathsetmacro{\wext}{1.5} +\pgfmathsetmacro{\dexthi}{\taus+\slope*\wext} +\pgfmathsetmacro{\dextlo}{\taus-\slope*\wext} +\begin{axis}[ + width=0.9\columnwidth, height=0.6\columnwidth, + axis lines=middle, + xlabel={$\omega$}, ylabel={$\tau$}, + xmin=-1.6, xmax=1.6, ymin=-1.6, ymax=1.6, + xtick={-\wz, \wz}, xticklabels={$-\omega_0$, {}}, + ytick={-\taus, \taus}, + yticklabels={$-\tau_0$, $\tau_0$}, + tick style={thick}, + every axis x label/.style={at={(ticklabel* cs:1)}, anchor=west}, + every axis y label/.style={at={(ticklabel* cs:1)}, anchor=south}, +] +\fill[blue, opacity=0.08] + (-\wcl, \taumax) -- (\wcu, \taumax) -- (\wcl, -\taumax) + -- (-\wcu, -\taumax) -- cycle; +\addplot[thick, dashed, gray] coordinates {(-\wext, \dexthi) (\wext, \dextlo)}; +\addplot[thick, dashed, gray] coordinates {(-\wext, -\dextlo) (\wext, -\dexthi)}; +\addplot[thick, dashed, gray] coordinates {(-1.55, \taumax) (1.55, \taumax)}; +\addplot[thick, dashed, gray] coordinates {(-1.55, -\taumax) (1.55, -\taumax)}; +\addplot[thick, blue] coordinates + {(-\wcl, \taumax) (\wcu, \taumax) (\wcl, -\taumax) (-\wcu, -\taumax) + (-\wcl, \taumax)}; +\node[font=\scriptsize, anchor=south] at (axis cs: \wz, \taumax) + {$\tau_{\max}$}; +\node[font=\scriptsize, anchor=north] at (axis cs: -\wz, -\taumax) + {$-\tau_{\max}$}; +\draw[thick, dashed, gray] (axis cs: \wz, 0) -- (axis cs: \wz, -\taumax); +\draw[thick, dashed, gray] (axis cs: -\wz, 0) -- (axis cs: -\wz, \taumax); +\node[font=\normalsize, anchor=south] at (axis cs: \wz, 0.05) {$\omega_0$}; +\end{axis} +\end{tikzpicture} +\caption{Torque-speed envelope with current saturation~\eqref{eq:saturation}.} +\label{fig:saturation} +\end{figure} + +\noindent Note that datasheets typically distinguish two current limits. The \emph{continuous} (or \emph{nominal}) current $i_{\max}$ is the thermal limit: the maximum current the motor can sustain indefinitely without exceeding its maximum winding temperature. The \emph{peak} current $i_{\text{peak}}$ is a higher short-term limit, typically 5--10$\times$ the continuous value, constrained by demagnetization or commutation limits. + +\begin{table}[H] +\centering +\footnotesize +\setlength{\tabcolsep}{3pt} +\renewcommand{\arraystretch}{1.2} +\begin{tabular}{@{}llll@{}} +\toprule +Symbol & Description & Condition & Formula/note\\ +\midrule +$\tau_0$ & Stall Torque & $\omega=0$ & $\tau_0 = Kv / R$ \\ +$\omega_0$ & No-Load Speed & $\tau_{\text{load}}=0$ & + $\omega_0 \approx v / K$ \\ +$\partial\omega / \partial\tau$ & Gradient & Slope & + $-R/K^2$ \\ +$i_{\max}$ & Maximum Current & Limit & Thermal limit \\ +$\tau_{\max}$ & Maximum Torque & Limit & $\tau_{\max} = K \, i_{\max}$ \\ +\bottomrule +\end{tabular} +\caption{Named constants derived from the motor equations.} +\label{tab:electromech_constants} +\end{table} + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Inductance} +\label{sec:inductance} + +Including the effects of winding inductance $L$ (Henry) means treating the current $i$ as a state variable: +\begin{equation} + v = L \, \frac{di}{dt} + i \, R + K \, \omega + \label{eq:inductance} +\end{equation} +The change in current is proportional to the voltage and negatively proportional to both the instantaneous current and the rotation velocity. The time constant of this ODE is $t_e = L/R$. If $t_e \ll \Delta t$ (the simulation timestep), the current equilibrates within a single step and the quasi-static approximation~\eqref{eq:torque_speed} is adequate. + +Motor drivers often impose a hard limit on $di/dt$ to protect windings and electronics, bounding the torque ramp rate to $K \cdot (di/dt)_{\max}$. + +% --------------------------------------------------------------------------- +% Mechanical Model +% --------------------------------------------------------------------------- +\subsection{Mechanical Model} +\label{sec:mechanical} + +Several purely mechanical phenomena affect the motor's behavior and the effective delivered torque. + +\paragraph{Mechanical losses.} +These reduce the net torque available at the shaft: $\tau_{\text{net}} = \tau_{\text{elec}} - \tau_{\text{loss}}$. +\begin{itemize} + \item \textbf{Coulomb:} Constant torque opposing rotation (dry friction). + $\tau_{\text{loss}} = \tau_c \, \text{sgn}(\omega)$. Discontinuous; already available in MuJoCo as \texttt{frictionloss}. + \item \textbf{Viscous:} Drag is a smooth function of speed $\tau_{\text{loss}} = b(\omega)$. The simplest model is linear, with drag proportional to speed $\tau_{\text{loss}} = B \, \omega$, but higher order terms may be needed for higher-fidelity models e.g., $\tau_{\text{loss}} = B_1 \, \omega + B_2 \, \omega |\omega| + B_3 \, \omega^3 + \dots$. +\end{itemize} + +\noindent Datasheets report the \emph{no-load current} $i_0$: the current drawn when spinning freely at no-load speed $\omega_0$. At steady state, the electromagnetic torque balances all mechanical losses: +\begin{equation} + K \, i_0 = \tau_c + B \, \omega_0 + \label{eq:noload} +\end{equation} +This provides one constraint on two unknowns ($\tau_c$ and $B$). Without additional data, the user must either assume one dominates or obtain friction measurements at multiple speeds. In MuJoCo terms, $\tau_c$ maps to \texttt{frictionloss} and $B$ to \texttt{damping}. + +\noindent Combining current saturation with both mechanical losses, the net torque is: +\begin{equation*} + \tau_{\text{net}} = \text{clip}\!\left( \frac{K}{R}(v - K \, \omega),\; + \pm K\, i_{\max} \right) - B \, \omega - \tau_c \, \text{sgn}(\omega) + \label{eq:net_torque} +\end{equation*} +The clipping applies to the electrical torque (current limit), while both friction terms are mechanical losses applied to the output post-clipping. Viscous drag $-B\omega$ tilts the envelope; Coulomb friction $-\tau_c\,\text{sgn}(\omega)$ shifts the right half ($\omega > 0$) down and the left half ($\omega < 0$) up, creating a $2\tau_c$ discontinuity at $\omega = 0$ (Figure~\ref{fig:drag_linear}). + +\begin{figure}[H] +\centering +\begin{subfigure}[t]{\columnwidth} +\centering +\begin{tikzpicture} +\pgfmathsetmacro{\taus}{1.3} +\pgfmathsetmacro{\wz}{1.0} +\pgfmathsetmacro{\taumax}{0.7} +\pgfmathsetmacro{\bvis}{0.3} +\pgfmathsetmacro{\tauf}{0.2} +\pgfmathsetmacro{\slope}{\taus/\wz} +\pgfmathsetmacro{\wcu}{(\taus-\taumax)/\slope} +\pgfmathsetmacro{\wcl}{(\taus+\taumax)/\slope} +\pgfmathsetmacro{\wext}{1.5} +\pgfmathsetmacro{\dragL}{\bvis*\wext} +\pgfmathsetmacro{\dragR}{-\bvis*\wext} +%% Right half vertices (ω ≥ 0): shifted down by τ_f +\pgfmathsetmacro{\Ra}{\taumax-\tauf} +\pgfmathsetmacro{\Rb}{\taumax-\bvis*\wcu-\tauf} +\pgfmathsetmacro{\Rc}{-\taumax-\bvis*\wcl-\tauf} +\pgfmathsetmacro{\Rd}{-\taumax-\tauf} +%% Left half vertices (ω ≤ 0): shifted up by τ_f +\pgfmathsetmacro{\La}{\taumax+\bvis*\wcl+\tauf} +\pgfmathsetmacro{\Lb}{\taumax+\tauf} +\pgfmathsetmacro{\Lc}{-\taumax+\tauf} +\pgfmathsetmacro{\Ld}{-\taumax+\bvis*\wcu+\tauf} +\begin{axis}[ + width=0.9\columnwidth, height=0.6\columnwidth, + axis lines=middle, + xlabel={$\omega$}, ylabel={$\tau$}, + xmin=-1.8, xmax=1.8, ymin=-1.8, ymax=1.8, + xtick=\empty, ytick=\empty, + tick style={thick}, + every axis x label/.style={at={(ticklabel* cs:1)}, anchor=west}, + every axis y label/.style={at={(ticklabel* cs:1)}, anchor=south}, +] +\fill[blue, opacity=0.08] + (0, \Ra) -- (\wcu, \Rb) -- (\wcl, \Rc) -- (0, \Rd) -- cycle; +\addplot[thick, blue] coordinates + {(0, \Ra) (\wcu, \Rb) (\wcl, \Rc) (0, \Rd)}; +\fill[blue, opacity=0.08] + (-\wcl, \La) -- (0, \Lb) -- (0, \Lc) -- (-\wcu, \Ld) -- cycle; +\addplot[thick, blue] coordinates + {(-\wcl, \La) (0, \Lb) (0, \Lc) (-\wcu, \Ld) (-\wcl, \La)}; +\addplot[thick, dashed, gray] coordinates {(0.01, {-\tauf-\bvis*0.01}) (\wext, {-\tauf+\dragR})}; +\addplot[thick, dashed, gray] coordinates {(-\wext, {\tauf+\dragL}) (-0.01, {\tauf+\bvis*0.01})}; +\node[font=\footnotesize, anchor=north west] at (axis cs: -1.75, -0.45) + {$-B\omega - \tau_c\,\text{sgn}(\omega)$}; +\draw[->, gray, thick] (axis cs: -1.4, -0.45) -- (axis cs: -1.3, {\tauf+\bvis}); +\pgfmathsetmacro{\gapmid}{(\Ra+\Lb)/2} +\draw[thick, <->, gray] (axis cs: 0.12, \Ra) -- (axis cs: 0.12, \Lb); +\node[font=\scriptsize, anchor=west] at (axis cs: 0.18, \gapmid) + {$2\tau_c$}; +\end{axis} +\end{tikzpicture} +\caption{Linear viscous drag and Coulomb friction.} +\label{fig:drag_linear} +\end{subfigure} + +\vspace{0.5em} + +\begin{subfigure}[t]{\columnwidth} +\centering +\begin{tikzpicture} +\pgfmathsetmacro{\taus}{1.3} +\pgfmathsetmacro{\wz}{1.0} +\pgfmathsetmacro{\taumax}{0.7} +\pgfmathsetmacro{\Bone}{0.15} +\pgfmathsetmacro{\Btwo}{0.35} +\pgfmathsetmacro{\slope}{\taus/\wz} +\pgfmathsetmacro{\wcl}{(\taus+\taumax)/\slope} +\pgfmathsetmacro{\wext}{1.7} +\begin{axis}[ + width=0.9\columnwidth, height=0.6\columnwidth, + axis lines=middle, + xlabel={$\omega$}, ylabel={$\tau$}, + xmin=-2.0, xmax=2.0, ymin=-2.0, ymax=2.0, + xtick=\empty, ytick=\empty, + tick style={thick}, + every axis x label/.style={at={(ticklabel* cs:1)}, anchor=west}, + every axis y label/.style={at={(ticklabel* cs:1)}, anchor=south}, + samples=200, +] +\addplot[name path=upper, thick, blue, domain=-\wcl:\wcl] + {min(\taus - \slope*x, \taumax) - \Bone*x - \Btwo*x*abs(x)}; +\addplot[name path=lower, thick, blue, domain=-\wcl:\wcl] + {max(-\taus - \slope*x, -\taumax) - \Bone*x - \Btwo*x*abs(x)}; +\addplot[blue, opacity=0.08] fill between[of=upper and lower]; +\addplot[thick, dashed, gray, domain=-\wext:\wext] + {-\Bone*x - \Btwo*x*abs(x)}; +\node[font=\footnotesize, anchor=north west] at (axis cs: -1.9, -0.5) + {$-b(\omega)$}; +\draw[->, gray, thick] (axis cs: -1.5, -0.5) -- (axis cs: -1.3, {0.1 + 0.35*1.3*1.3}); +\end{axis} +\end{tikzpicture} +\caption{Nonlinear viscous drag $b(\omega) = B_1\omega + B_2\omega|\omega|$, no friction.} +\label{fig:drag_nonlinear} +\end{subfigure} + +\caption{Torque-speed envelopes with mechanical losses. The dashed gray line shows the drag function; the shaded region is the achievable torque at each speed. Note that datasheet torque-speed curves typically plot the first quadrant only.} +\label{fig:drag} +\end{figure} + +\paragraph{Rotor Inertia and Gearing.} +Every DC motor datasheet lists the rotor inertia $J_r$ (kg$\cdot$m$^2$). When a gear train with ratio $N$ is attached, the effective inertia reflected to the output shaft is $J_{\text{eff}} = J_r N^2$~\cite{tedrake2024}. Note that real gearboxes also introduce efficiency losses (typically 70--90\%), which reduce the transmitted torque by a multiplicative factor $\eta$, approximated by effectively reducing the motor constant $K_{\text{eff}} = \eta K$. + +\paragraph{Cogging Torque.} +Brushless DC motors exhibit \emph{cogging torque}: a position-dependent torque ripple caused by the interaction between permanent magnets and stator slots. It can be modeled as a periodic bias: +\begin{equation} + \tau_{\text{cog}}(\theta) = A \sin(N_p \, \theta + \phi) + \label{eq:cogging} +\end{equation} +where $A$ is the amplitude, $N_p$ is the number of pole pairs times the number of slots per pole, and $\phi$ is a phase offset. Cogging torque is significant primarily at low speeds. Datasheets sometimes report peak cogging as a percentage of rated torque (typically 1--5\%). + +\begin{table}[H] +\centering +\footnotesize +\setlength{\tabcolsep}{3pt} +\renewcommand{\arraystretch}{1.2} +\begin{tabular}{@{}lll@{}} +\toprule +Symbol & Description & Formula / Note \\ +\midrule +$\tau_c$ & Coulomb friction & $\tau_c\,\text{sgn}(\omega)$ \\ +$B$ & Viscous drag (linear) & $B\,\omega$ \\ +$\omega_0$ & No-load speed & + $\omega_0 = v\,K / (K^2 + R\,B)$ \\ +$J_r$ & Rotor inertia & units: kg$\cdot$m$^2$ \\ +$N$ & Gear ratio & $J_{\text{eff}} = J_r N^2$ \\ +$\eta$ & Gearbox efficiency & $K' = \eta \, K$ \\ +$A$ & Cogging amplitude & $\tau_{\text{cog}} = A\sin(N_p\theta + \phi)$ \\ +$N_p$ & Cogging periodicity & poles $\times$ slots/pole \\ +$\phi$ & Cogging phase & offset \\ +\bottomrule +\end{tabular} +\caption{Named constants related to mechanical properties. Note that unlike in Table~\ref{tab:electromech_constants}, the non-approximate expression for $\omega_0$ takes into account the linear drag $B$ (assuming no high-order terms).} +\label{tab:key_constants} +\end{table} + +\paragraph{Backlash.} +Gearboxes introduce backlash: a small angular deadband where the motor can turn without moving the output shaft. Datasheets report this in arcminutes. MuJoCo supports backlash modeling via a dual-joint decomposition; \href{https://mujoco.readthedocs.io/en/stable/modeling.html#backlash}{see here} for details. + +% --------------------------------------------------------------------------- +% Thermal Model +% --------------------------------------------------------------------------- +\subsection{Thermal Model} +\label{sec:thermal} + +Winding temperature affects motor performance primarily through increased copper resistance, and can be modeled as a single lumped thermal state. The thermal constants are + +\begin{table}[H] +\centering +\footnotesize +\begin{tabular}{@{}lll@{}} +\toprule +Symbol & Description & Units \\ +\midrule +$R_T$ & Thermal resistance & Kelvin/Watt \\ +$C$ & Thermal capacitance & Joule/Kelvin \\ +$t_T = R_T C$ & Thermal time constant & second \\ +$\alpha$ & Resistance temp.\ coefficient & 1/Kelvin \\ +$T_0$ & Reference temperature & degree Celsius \\ +$T_a$ & Ambient temperature & degree Celsius \\ +\bottomrule +\end{tabular} +\caption{Thermal model constants. Units involving temperature differences use Kelvin (equivalent to Celsius for differences); absolute temperatures use degree Celsius, following datasheet convention.} +\end{table} + +\noindent Note that some manufacturers specify two thermal resistances: $R_{\text{th1}}$ (winding-to-housing) and $R_{\text{th2}}$ (housing-to-ambient), which sum to give the total winding-to-ambient thermal resistance $R_T = R_{\text{th1}} + R_{\text{th2}}$. The single-node model above uses $R_T$ directly. + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Lumped Thermal ODE} +\label{sec:thermal_ode} + +The winding temperature $T$ evolves according to a first-order lumped model driven by the power dissipation $P$ (Watt, detailed in \S\ref{sec:thermal_losses}): +\begin{equation} + \frac{\partial T}{\partial t} = \frac{1}{C} P - \frac{T - T_a}{t_T} + \label{eq:thermal_ode} +\end{equation} +where $t_T = R_T C$ is the thermal time constant. This produces exponential rise/decay toward a steady-state temperature $T_{ss} = T_a + R_T P$ (Figure~\ref{fig:thermal_response}). + +\begin{figure}[ht] +\centering +\begin{tikzpicture} +\pgfmathsetmacro{\Tss}{1.0} +\pgfmathsetmacro{\ttau}{1.0} +\pgfmathsetmacro{\xmax}{4.5} +\begin{axis}[ + width=0.9\columnwidth, height=0.45\columnwidth, + axis lines=left, + clip=false, + xlabel={$t$}, ylabel={$T - T_a$}, + xmin=0, xmax=\xmax, ymin=0, ymax={\Tss*1.25}, + xtick={\ttau}, xticklabels={$t_T$}, + ytick={{\Tss*(1-exp(-1))}, \Tss}, + yticklabels={$(1{-}1/e)\,R_T P$, $R_T P$}, + tick style={thick}, + every axis x label/.style={at={(ticklabel* cs:1)}, anchor=west}, + every axis y label/.style={at={(ticklabel* cs:1)}, anchor=south}, +] +\addplot[thick, blue, domain=0:\xmax, samples=100] + {\Tss*(1 - exp(-x/\ttau))}; +\addplot[thick, dashed, gray] coordinates {(0,\Tss) (\xmax,\Tss)}; +\draw[thick, dashed, gray] (axis cs:\ttau, 0) -- (axis cs:\ttau, {\Tss*(1-exp(-1))}); +\draw[thick, dashed, gray] (axis cs:0, {\Tss*(1-exp(-1))}) -- (axis cs:\ttau, {\Tss*(1-exp(-1))}); +\node[font=\scriptsize, anchor=south] at (axis cs:\xmax*0.5, \Tss) + {$T_{ss} = T_a + R_T P$}; +\end{axis} +\end{tikzpicture} +\caption{Temperature rise under constant power dissipation $P$. + At $t = t_T$, it reaches $(1-1/e) \approx 63\%$ of its steady-state value.} +\label{fig:thermal_response} +\end{figure} + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Losses} +\label{sec:thermal_losses} + +The dominant loss is copper (Joule) heating: +\begin{equation*} + P = i^2 R(T) + \label{eq:copper_loss} +\end{equation*} +Optionally, speed-dependent iron losses (eddy-current and hysteresis losses in the stator laminations) can be included~\cite{hughes2019}: +\begin{equation*} + P = i^2 R(T) + K_{\text{fe}} \omega^2 + \label{eq:total_loss} +\end{equation*} +The iron loss coefficient $K_{\text{fe}}$ is not typically listed on datasheets and must be identified from efficiency curves or manufacturer simulation tools. For most hobby and robotics motors, iron losses are small compared to copper losses and can be neglected. They become significant at high speeds in large industrial motors. + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Temperature-Dependent Resistance} +\label{sec:resistance_temperature} + +Copper resistance increases approximately linearly with temperature: +\begin{equation} + R(T) = R_0 \left(1 + \alpha (T - T_0)\right) + \label{eq:resistance_temperature} +\end{equation} +where $R_0$ is resistance at reference temperature $T_0$ and $\alpha \approx 0.0039 \, \text{K}^{-1}$ for copper. This is the dominant thermal feedback: as $T$ rises, $R$ increases, so for a given voltage the current $i = (v - K \omega) / R(T)$ drops, reducing torque. + +Note that $R(T)$ also increases heating for a given current ($P = i^2 R(T)$), creating mild positive feedback under current control. + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Magnet Flux Derating} +\label{sec:magnet_derating} + +Permanent magnet flux weakens with temperature, reducing $K$: +\begin{equation*} + K(T) = K_0 \left(1 + \alpha_m (T - T_0)\right) + \label{eq:kt_temperature} +\end{equation*} +with $\alpha_m < 0$ (motor-dependent). This effect is often small over normal operating ranges and can be ignored. + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Thermal Derating} +\label{sec:thermal_derating} + +Real actuators limit current as winding temperature approaches a maximum: +\begin{equation*} + i_{\max}(T) = \begin{cases} + i_{\text{rated}} & T \le T_1 \\ + i_{\text{safe}} + (i_{\text{rated}} - i_{\text{safe}}) \, s(T) & T_1 < T < T_2 \\ + i_{\text{safe}} & T \ge T_2 + \end{cases} +\end{equation*} +where $s(T)$ is a smooth interpolant between $T_1$ and $T_2$. This reduces the maximum available torque as the motor heats up. + +% --------------------------------------------------------------------------- +% Micro-Friction Models +% --------------------------------------------------------------------------- +\subsection{Micro-Friction Models} +\label{sec:micro_friction} + +Simple macroscopic friction models (Coulomb, viscous) cannot capture complex mechanical phenomena common in real motors with gear trains, such as pre-sliding hysteresis and stick-slip limit cycles. To capture these behaviors, a richer dynamic model is required. At the microscopic level, two surfaces in contact touch at many asperities which deform elastically under tangential load. This can be modeled as an average bristle deflection $z$, governed by a first-order ODE driven by the relative velocity $\omega$. Friction torque is then a function of $z$, $\dot{z}$, and $\omega$. + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Dahl Model} +\label{sec:dahl} + +The simplest stateful model~\cite{dahl68} treats friction as a rate-independent hysteresis operator derived from the stress-strain curve: +\begin{equation*} + \dot{z} = \omega - \frac{\sigma_0}{\tau_c} |\omega| \, z + \label{eq:dahl_state} +\end{equation*} +with output $\tau = -\sigma_0 z$. In steady state ($\dot{z}=0$), $z_{ss} = \tau_c \, \text{sgn}(\omega) / \sigma_0$ so $\tau_{ss} = -\tau_c \, \text{sgn}(\omega)$: pure Coulomb friction opposing motion. The two parameters are the bristle stiffness $\sigma_0$ (torque/radian) and the Coulomb friction torque $\tau_c$. +For small displacements the model is approximately linear ($\tau \approx -\sigma_0 \theta$), giving spring-like pre-sliding behavior with hysteresis during direction reversals (Figure~\ref{fig:hysteresis}). The Dahl model does not capture the Stribeck effect~\cite{stribeck1902} (the drop in friction at low velocity) and thus cannot predict stick-slip motion. + +\begin{figure}[H] +\centering +\begin{tikzpicture} +\begin{axis}[ + width=0.9\columnwidth, height=0.55\columnwidth, + axis lines=middle, + xlabel={$\theta$}, ylabel={$\tau$}, + xmin=-1.4, xmax=1.4, ymin=-1.4, ymax=1.4, + xtick=\empty, ytick=\empty, + every axis x label/.style={at={(ticklabel* cs:1)}, anchor=west}, + every axis y label/.style={at={(ticklabel* cs:1)}, anchor=south}, + clip=false, +] +\pgfmathsetmacro{\sig}{2.5} +\pgfmathsetmacro{\Fc}{1.0} +\pgfmathsetmacro{\xm}{1.0} +\pgfmathsetmacro{\ch}{(exp(\sig*\xm)+exp(-\sig*\xm))/2} +\addplot[thick, blue, domain=-\xm:\xm, samples=150, name path=lower] + {-\Fc*(1 - exp(-\sig*x)/\ch)}; +\addplot[thick, blue, domain=-\xm:\xm, samples=150, name path=upper] + {\Fc*(1 - exp(\sig*x)/\ch)}; +\addplot[blue, opacity=0.08] fill between[of=lower and upper]; +\draw[thick, dotted] (axis cs:-1.4, -\Fc) -- (axis cs:1.4, -\Fc) + node[right, font=\scriptsize] {$-\tau_c$}; +\draw[thick, dotted] (axis cs:-1.4, \Fc) -- (axis cs:1.4, \Fc) + node[right, font=\scriptsize] {$\tau_c$}; +\draw[->, thick, gray] (axis cs:0.05, -0.78) -- (axis cs:0.25, -0.83); +\draw[->, thick, gray] (axis cs:-0.05, 0.78) -- (axis cs:-0.25, 0.83); +\end{axis} +\end{tikzpicture} +\caption{Hysteresis loop: friction torque $\tau$ vs.\ displacement $\theta$ under slow + periodic loading (Dahl model). + The loop area represents energy dissipated per cycle. + Unlike memoryless Coulomb friction, the torque is continuous.} +\label{fig:hysteresis} +\end{figure} + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{LuGre Model} +\label{sec:lugre} + +The LuGre model~\cite{dewit95, lugre_revisited} extends Dahl by making the bristle saturation velocity-dependent and adding micro-damping and viscous terms: +\begin{subequations} +\label{eq:lugre} +\begin{align} + \dot{z} &= \omega - \sigma_0 \frac{|\omega|}{g(\omega)} z + \label{eq:lugre_state} \\ + \tau &= -(\sigma_0 z + \sigma_1 \dot{z} + \sigma_2 \omega) + \label{eq:lugre_force} +\end{align} +\end{subequations} +where the function $g(\omega)$ captures the Stribeck effect: +\begin{equation} + g(\omega) = \tau_c + (\tau_s - \tau_c) \, e^{-(\omega/\omega_s)^\gamma} + \label{eq:stribeck} +\end{equation} +with exponent $\gamma$ typically 1 or 2. In steady state, $\tau_{ss}(\omega) = -g(\omega) \, \text{sgn}(\omega) - \sigma_2 \omega$: the classic Stribeck curve (Figure~\ref{fig:stribeck}). + +\begin{figure}[H] +\centering +\begin{tikzpicture} +\begin{axis}[ + width=0.9\columnwidth, height=0.55\columnwidth, + axis lines=middle, + xlabel={$\omega$}, ylabel={$\tau_{ss}$}, + xmin=-3, xmax=3, ymin=-2.5, ymax=2.5, + xtick=\empty, ytick=\empty, + every axis x label/.style={at={(ticklabel* cs:1)}, anchor=west}, + every axis y label/.style={at={(ticklabel* cs:1)}, anchor=south}, + clip=false, +] +\pgfmathsetmacro{\Fc}{1.0} +\pgfmathsetmacro{\Fs}{1.8} +\pgfmathsetmacro{\vs}{0.5} +\pgfmathsetmacro{\sigtwo}{0.15} +\addplot[thick, blue, domain=0.01:3, samples=200] + {-(\Fc + (\Fs-\Fc)*exp(-(x/\vs)^2)) - \sigtwo*x}; +\addplot[thick, blue, domain=-3:-0.01, samples=200] + {(\Fc + (\Fs-\Fc)*exp(-(-x/\vs)^2)) - \sigtwo*x}; +\addplot[thick, dashed, gray, domain=-3:3, samples=2] {-\sigtwo*x}; +\draw[thick, dotted] (axis cs:0,-\Fs) -- (axis cs:3,-\Fs) + node[right, font=\scriptsize] {$-\tau_s$}; +\draw[thick, dotted] (axis cs:0,-\Fc) -- (axis cs:3,-\Fc) + node[right, font=\scriptsize] {$-\tau_c$}; +\draw[thick, dotted] (axis cs:0,\Fs) -- (axis cs:-3,\Fs) + node[left, font=\scriptsize] {$\tau_s$}; +\draw[thick, dotted] (axis cs:0,\Fc) -- (axis cs:-3,\Fc) + node[left, font=\scriptsize] {$\tau_c$}; +\node[font=\footnotesize, anchor=north east] at (axis cs:-0.5, -0.1) + {$-\sigma_2\omega$}; +\end{axis} +\end{tikzpicture} +\caption{Steady-state friction $\tau_{ss}(\omega) = -g(\omega)\,\text{sgn}(\omega) - \sigma_2 \omega$. + Stiction torque $\tau_s$ at $\omega\!=\!0$ drops to Coulomb level $\tau_c$ + over velocity scale $\omega_s$ (Stribeck effect). Friction opposes motion.} +\label{fig:stribeck} +\end{figure} + +\noindent The Dahl model is recovered by setting $g(\omega) = \tau_c$ and $\sigma_1 = \sigma_2 = 0$. Linearizing around $\omega = z = 0$ gives second-order dynamics $J\ddot{\theta} - (\sigma_1 + \sigma_2)\dot{\theta} - \sigma_0 \theta = \tau$ (applied torque): a spring-damper with natural frequency $\omega_n = \sqrt{\sigma_0/J}$, critically damped when $\sigma_1 = 2\sqrt{J\sigma_0}$. + +The LuGre model can be shown to be input-strictly-passive (the map $\omega \mapsto \tau$ dissipates energy) provided $\sigma_2 > \sigma_1 (\tau_s - \tau_c)/\tau_c$. This passivity condition limits $\sigma_1$ and can lead to underdamped micro-dynamics, motivating the following extension. + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Velocity-Dependent Damping} +\label{sec:vel_damping} + +The passivity constraint on $\sigma_1$ can be relaxed by making the micro-damping decrease with velocity: +\begin{equation*} + \sigma_1(\omega) = \bar{\sigma}_1\, e^{-(\omega/\omega_s)^\beta} + \label{eq:sigma1_vel} +\end{equation*} +This allows large damping in the stiction regime (good for numerical stability and physical fidelity) while satisfying passivity at higher velocities where $\sigma_1 \to 0$. + +Together with $\tau_c$ from \S\ref{sec:mechanical}, the LuGre model with velocity-dependent damping adds six parameters: +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Symbol & Description & Units \\ +\midrule +$\sigma_0$ & Bristle stiffness, pre-sliding slope & N$\cdot$m/rad \\ +$\bar{\sigma}_1$ & Peak bristle damping at $\omega = 0$ & N$\cdot$m$\cdot$s/rad \\ +$\sigma_2$ & Viscous damping coefficient & N$\cdot$m$\cdot$s/rad \\ +$\tau_s$ & Stiction torque, $\tau_s \ge \tau_c$ & N$\cdot$m \\ +$\omega_s$ & Stribeck velocity & rad/s \\ +$\beta$ & Damping decay exponent & dimensionless \\ +\bottomrule +\end{tabular} +\caption{Parameters of the LuGre friction model (\S\ref{sec:lugre}--\ref{sec:vel_damping}).} +\label{tab:lugre_params} +\end{table} + + +\newpage +%============================================================================= +% IMPLEMENTATION +%============================================================================= +\section{Implementation} +\label{sec:implementation} + +Here we describe MuJoCo's \texttt{dcmotor} actuator. Some scalars are grouped into vectors; we use a colon to denote such scalar sub-attributes, e.g.\ \texttt{cogging:phase} refers to the third element of the \texttt{cogging} attribute (see Tables \ref{tab:mjcf_attributes} and \ref{tab:cogging_impl}). + +\begin{table}[H] +\centering +\footnotesize +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Size & Description \\ +\midrule +\texttt{resistance} & 1 & Terminal resistance $R$ \\ +\texttt{motorconst} & 2 & Motor constants ($K_t, K_e$; see below) \\ +\texttt{nominal} & 3 & Nominal operating point ($v_n, \tau_0, \omega_0$) \\ +\texttt{inductance} & 2 & Electrical dynamics ($L, t_e$) \\ +\texttt{thermal} & 6 & Thermal model ($R_T, C, t_T, \alpha, T_0, T_a$) \\ +\texttt{saturation} & 4 & Limits ($\tau_{\max}, i_{\max}, v_{\max}, (di{/}dt)_{\max}$) \\ +\midrule +\texttt{cogging} & 3 & Cogging torque ($A, N_p, \phi$) \\ +\texttt{lugre} & 6 & LuGre friction ($\sigma_0, \sigma_1, \sigma_2, \tau_c, \tau_s, \omega_s$) \\ +\texttt{damping} & 3 & Viscous damping coefficients \\ +\texttt{armature} & 1 & Armature inertia \\ +\midrule +\texttt{input} & keyword & Mode (voltage/position/velocity) \\ +\texttt{controller} & 5 & Gains and slew ($k_p, k_i, k_d, s, I_{\max}$) \\ +\bottomrule +\end{tabular} +\caption{MJCF attributes for the \texttt{dcmotor} actuator, split into electrical, mechanical and control groupings.} +\label{tab:mjcf_attributes} +\end{table} + +% --------------------------------------------------------------------------- +% Stateless dcmotor +% --------------------------------------------------------------------------- +\subsection{Stateless DC Motor} +\label{sec:dcmotor} + +The output torque of the stateless motor follows Eq.~\eqref{eq:saturation}, mapping physical parameters to the underlying affine model. The three core parameters are the effective motor constant $K$, resistance $R$, and maximum torque $\tau_{\max}$. +They are stored in \texttt{mjModel} as follows: + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Symbol & Description & \texttt{mjModel} storage \\ +\midrule +$R$ & Resistance & \texttt{gainprm[0]} \\ +$K$ & Effective motor constant & \texttt{gainprm[1]} \\ +$\tau_{\max}$ & Maximum torque & \texttt{forcerange} \\ +\bottomrule +\end{tabular} +\caption{Stateless DC motor core parameters.} +\end{table} + +\noindent The gain $G = K/R$ and back-EMF bias $-GK\omega$ are computed at runtime. Storing $R$ separately allows temperature-dependent resistance $R(T)$, Eq.~\eqref{eq:resistance_temperature}, to be applied. These three core parameters can be specified with a combination of eight sub-attributes + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Symbol & Units \\ +\midrule +\atR{} & $R$ & Ohm \\ +\atKt{} & $K_t$ & N$\cdot$m/A \\ +\atKe{} & $K_e$ & V$\cdot$s/rad \\ +\atVM{} & $v_n$ & Volt \\ +\atSTALL{} & $\tau_0\!=\!Kv_n/R$ & N$\cdot$m \\ +\atNLS{} & $\omega_0\!\approx\!v_n/K$ & rad/s \\ +\atTMAX{} & $\tau_{\max}$ & N$\cdot$m \\ +\atIMAX{} & $i_{\max}$ & Ampere \\ +\bottomrule +\end{tabular} +\caption{Stateless DC motor basic attributes.} +\end{table} + +\noindent The attribute \atK{} has two sub-attributes, \texttt{Kt} and \texttt{Ke}. If both are positive, $K = \sqrt{K_t K_e}$ (preserving power balance $K^2 = K_t K_e$). If only one is positive, $K$ equals that value. + +\pagebreak +\noindent The following attribute combinations are supported: +\begin{enumerate}[itemsep=2pt, parsep=0pt, topsep=2pt] + \item Effective motor constant $K$, one of: + \begin{itemize}[itemsep=1pt, parsep=0pt, topsep=1pt] + \item \atKt{} \emph{and/or} \atKe{} + \item \atNLS{} \emph{and} \atVM{} + \end{itemize} + \item Resistance $R$, one of: + \begin{itemize}[itemsep=1pt, parsep=0pt, topsep=1pt] + \item \atR{} + \item \atSTALL{} \emph{and} \atVM{} + \end{itemize} + \item Maximum torque $\tau_{\max}$, one of: + \begin{itemize}[itemsep=1pt, parsep=0pt, topsep=1pt] + \item \atTMAX{} + \item \atIMAX{} + \end{itemize} +\end{enumerate} + +\noindent \atIMAX{} corresponds to the continuous (thermal) current limit. Peak current behavior can be approximated with the thermal model (\S\ref{sec:temperature_impl}). + +\paragraph{Rotor Inertia and Gearing.} To model rotor inertia with a gear train, set the actuator's \texttt{armature} $= J_r$ and \texttt{gear} $= N$. Actuator-level \texttt{armature} automatically scales the inertia by $N^2$ to reflect $J_{\text{eff}}$ to the output shaft. + +\paragraph{Mechanical Drag.} The full torque-speed envelope applies viscous drag \emph{outside} the current clamp: +\begin{equation*} + \tau_{\text{net}} = \text{clip}\!\left( \frac{K}{R}(v - K \, \omega),\; + \pm \tau_{\max} \right) - b(\omega) + \label{eq:drag} +\end{equation*} +Actuator-level \texttt{damping} reproduces this post-clamp behavior. It accepts an array of polynomial drag coefficients (\texttt{damping[0]} $= B_1$, \texttt{damping[1]} $= B_2$, $\dots$). As with \texttt{armature}, actuator-level \texttt{damping} is scaled by $N^2$. + +\paragraph{Cogging Torque.} The magnetic torque ripple of Eq.~\eqref{eq:cogging} is modeled as a periodic bias added to the actuator force, where $\theta$ is the \texttt{actuator\_length}, i.e.\ the transmission-transformed joint angle. + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Symbol & Units \\ +\midrule +\atCOGA{} & $A$ & N$\cdot$m \\ +\atCOGP{} & $N_p$ & dimensionless \\ +\atCOGPH{} & $\phi$ & radian \\ +\bottomrule +\end{tabular} +\caption{Cogging torque attributes.} +\label{tab:cogging_impl} +\end{table} + +\paragraph{Mapping to Isaac Lab.} +Isaac Lab~\cite{isaaclab2025} implements the stateless DC motor model. Table~\ref{tab:mapping} maps its attributes to the constants defined in this document and to the \texttt{dcmotor} attributes. + +\begin{table}[H] +\centering +\footnotesize +\begin{tabular}{@{}lll@{}} +\toprule +Symbol & MuJoCo & Isaac Lab \\ +\midrule +$\tau_0$ & \atSTALL{} & \texttt{saturation\_effort} \\ +$\omega_0$ & \atNLS{} & \texttt{velocity\_limit} \\ +$\tau_{\max}$ & \atTMAX{} & \texttt{effort\_limit} \\ +\bottomrule +\end{tabular} +\caption{Mapping of attributes to Isaac Lab.} +\label{tab:mapping} +\end{table} + +\noindent Isaac Lab does not expose electrical parameters. To reproduce the same torque-speed envelope, set \atVM{} to any positive value (e.g.,~\texttt{1}) and \texttt{ctrlrange} to $\pm$\,that value (e.g., \texttt{"-1 1"}). + +\paragraph{Gearbox Efficiency.} Gearbox efficiency $\eta$ is not a separate attribute. To account for transmission losses, reduce the motor constant: $K \rightarrow \eta K$. This correctly reduces forward torque transmission. + +\paragraph{Computed parameters.} +Several derived quantities that appear on datasheets can be computed and used to cross-check the parameterization. The torque-speed gradient $\partial\omega/\partial\tau = -R/K^2$ gives the slope of the torque-speed line (Table~\ref{tab:electromech_constants}). The mechanical time constant $t_m = R\,J/K^2$ is the time for the motor to reach 63\% of its no-load speed under a voltage step, where $J$ is the rotor inertia (\texttt{armature}). The nominal (continuous) torque is $\tau_n = K \cdot i_{\max}$. The no-load current $i_0$ can be computed from Eq.~\eqref{eq:noload} given known friction parameters. See Table~\ref{tab:datasheet}. + +\paragraph{Not modeled:} +Nonlinear torque constant $K_t(i)$. Separate $K_t$ and $K_e$ values are accepted via \atKt{} and \atKe{} but collapsed to a single effective $K = \sqrt{K_t K_e}$. + + +% --------------------------------------------------------------------------- +% Stateful Current +% --------------------------------------------------------------------------- +\subsection{Stateful Current} +\label{sec:current_impl} + +A winding current state variable governed by Eq.~\eqref{eq:inductance} is added if the electrical time constant $t_e > 0$ (derived from inductance $L > 0$ or specified directly). When enabled, the state is integrated by \texttt{mjDYN\_DCMOTOR}, and the gain switches from $K/R$ (stateless) to $K$ (stateful). +The time constant $t_e$ can be determined by either: + +\begin{itemize}[itemsep=0pt, parsep=0pt, topsep=2pt] + \item \atTE{} + \item \atL{} \emph{and} \atR{} (via $t_e = L/R$) +\end{itemize} + +\paragraph{Current rate limiting.} When the sub-attribute \atCRATE{} is set ($(di/dt)_{\max} > 0$) and the current state is enabled ($t_e > 0$), the rate of change of current is clamped: +\begin{equation*} + \frac{di}{dt} \leftarrow \text{clip}\!\left(\frac{di}{dt},\; \pm(di/dt)_{\max}\right) +\end{equation*} +This limits the torque ramp rate to $K \cdot (di/dt)_{\max}$ (N$\cdot$m/s) without requiring any additional state variables, since the current $i$ is already an activation variable and we are simply clamping its rate of change. This attribute has no effect when $t_e = 0$ (stateless current). + +\begin{table}[H] +\centering +\footnotesize +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Symbol & Units \\ +\midrule +\atL{} & $L$ & Henry \\ +\atTE{} & $t_e\!=\!L/R$ & second \\ +\atCRATE{} & $(di{/}dt)_{\max}$ & Ampere/second \\ +\bottomrule +\end{tabular} +\caption{Stateful current attributes.} +\end{table} + +% --------------------------------------------------------------------------- +% Temperature +% --------------------------------------------------------------------------- +\subsection{Temperature} +\label{sec:temperature_impl} + +A winding temperature state governed by the lumped ODE~\eqref{eq:thermal_ode} is added if any of the thermal attributes ($R_T, C, t_T$) are specified. The state $T$ is the temperature rise above ambient ($T = T_{\text{winding}} - T_a$), so the absolute temperature is $T + T_a$. Temperature modifies the winding resistance via Eq.~\eqref{eq:resistance_temperature}, which feeds back into the motor equation: higher temperature increases resistance, leading to reduced current for a given voltage. + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Symbol & Units \\ +\midrule +\atRT{} & $R_T$ & K/W \\ +\atTC{} & $C$ & J/K \\ +\atTT{} & $t_T\!=\!R_T C$ & s \\ +\atALPHA{} & $\alpha$ & 1/K \\ +\atTREF{} & $T_0$ & \textdegree C \\ +\atTAMB{} & $T_a$ & \textdegree C \\ +\bottomrule +\end{tabular} +\caption{Thermal model attributes.} +\end{table} + +\noindent The time constant $t_T$ can be determined by either: +\begin{itemize}[itemsep=0pt, parsep=0pt, topsep=2pt] + \item \atTT{} + \item \atRT{} \emph{and} \atTC{} +\end{itemize} + +\paragraph{Not modeled:} +Iron losses (\S\ref{sec:thermal_losses}), magnet flux derating (\S\ref{sec:magnet_derating}), and thermal current derating (\S\ref{sec:thermal_derating}). Only copper losses ($i^2 R$) drive the thermal model; $K$ is treated as temperature-independent. + +% --------------------------------------------------------------------------- +% Stateful Friction +% --------------------------------------------------------------------------- +\subsection{Stateful Friction} +\label{sec:friction_impl} + +A bristle deflection state governed by the LuGre model (\S\ref{sec:lugre}) is added if the bristle stiffness $\sigma_0 > 0$. + +The Stribeck function $g(\omega)$, Eq.~\eqref{eq:stribeck}, determines velocity-dependent friction, and the friction force is given by Eq.~\eqref{eq:lugre_force}. The bristle state is integrated using the exact ZOH scheme~\eqref{eq:zoh}. The viscous term $\sigma_2 \omega$ is mapped directly to the standard \texttt{actuator\_damping} attribute to leverage MuJoCo's implicit integration, while maintaining the $\sigma_2$ \texttt{lugre} sub-attribute for convenience. +\paragraph{Integration.} +The bristle stiffness $\sigma_0$ is typically very large ($10^5$--$10^6$ N$\cdot$m/rad), creating a stiff ODE. At constant velocity, the state equation~\eqref{eq:lugre_state} has the form $\dot{z} = a z + b \omega$ where $a = -\sigma_0 |\omega| / g(\omega)$ and $b = 1$. Euler integration is unstable unless $|1 + a \Delta t| < 1$, requiring impractically small timesteps ($\Delta t < 2g(\omega)/(\sigma_0 |\omega|)$, on the order of microseconds). +Under a zero-order hold assumption ($\omega$ constant over the timestep), the linear ODE $\dot{z} = az + b\omega$ can be solved exactly: +\begin{equation} + z_{k+1} = e^{a \Delta t} z_k + + \frac{b(e^{a \Delta t} - 1)}{a} \, \omega + \label{eq:zoh} +\end{equation} +reducing to $z_{k+1} = z_k + b\omega\Delta t$ in the limit $a \to 0$. This integration is unconditionally stable for any $\Delta t$. + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Symbol & Units \\ +\midrule +\atSIG{} & $\sigma_0$ & N$\cdot$m/rad \\ +\atSIGD{} & $\sigma_1$ & N$\cdot$m$\cdot$s/rad \\ +\atSIGV{} & $\sigma_2$ & N$\cdot$m$\cdot$s/rad \\ +\atTAUC{} & $\tau_c$ & N$\cdot$m \\ +\atTAUS{} & $\tau_s$ & N$\cdot$m \\ +\atWS{} & $\omega_s$ & rad/s \\ +\bottomrule +\end{tabular} +\caption{LuGre friction attributes.} +\end{table} + +\noindent\textbf{Not modeled:} +Velocity-dependent bristle damping $\sigma_1(\omega)$ (\S\ref{sec:vel_damping}), a constant $\sigma_1$ is used. The Stribeck exponent is not exposed and fixed at $\gamma = 2$. + +\newpage +% --------------------------------------------------------------------------- +% PID Controller +% --------------------------------------------------------------------------- +\subsection{PID Controller} +\label{sec:controller} + +Many actuators embed an on-board controller computing drive voltage from position or velocity commands. To model such actuators, \texttt{dcmotor} supports an optional controller layer upstream of the motor physics. Two attributes control this behavior: + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Type & Description \\ +\midrule +\texttt{input} & keyword & \texttt{voltage}, \texttt{position}, \texttt{velocity} \\ +\texttt{controller} & vector & Gains (mode-dependent) \\ +\bottomrule +\end{tabular} +\caption{Controller attributes. Default \texttt{input} is \texttt{voltage}.} +\label{tab:controller_attributes} +\end{table} + +\noindent Unlike the motor parameters in Table~\ref{tab:datasheet}, controller gains are user-specified firmware settings. Gains are in {\em voltage-space} (e.g., $k_p$ in V/rad) since the output is a voltage $v$. To convert from physical torque-space (N$\cdot$m/rad), multiply by $R/K$. + +The controller computes a target voltage $v$ from the \texttt{ctrl} command. All motor physics --- cogging, saturation, friction, etc. --- apply identically downstream of $v$. The \texttt{input} attribute selects the controller: + +\begin{figure}[H] +\centering +\begin{tikzpicture}[ + block/.style={draw, rounded corners=2pt, minimum height=1.6em, + font=\scriptsize, fill=blue!5}, + mode/.style={font=\scriptsize, text=blue!70!black}, + arr/.style={->, thick, >=stealth}, + every node/.style={inner sep=2pt}, +] +% ctrl input +\node[font=\small] (ctrl) at (0, 3.5) {Input $u = {}$\texttt{ctrl}}; + +% Mode selector box +\node[block, minimum width=5.5cm, minimum height=6.5em, align=center] + (sel) at (0, 1.8) {}; +\node[font=\footnotesize\bfseries, anchor=north] at (0, 2.7) + {Controller mode}; +\node[mode] at (0, 1.45) {$\begin{aligned} + \texttt{voltage:}\quad v &= u \\[2pt] + \texttt{position:}\quad v &= k_p(u\!-\!\theta) + k_i x_I - k_d\dot\theta \\[2pt] + \texttt{velocity:}\quad v &= k_p(u\!-\!\dot\theta) + k_i(x_I\!-\!\theta) +\end{aligned}$}; + +% arrow ctrl to mode +\draw[arr] (ctrl.south) -- (sel.north); + +% Motor block +\node[block, font=\footnotesize\bfseries, minimum width=5.5cm, minimum height=2.2em, align=center] + (motor) at (0, -0.8) {DC Motor physics}; + +% single arrow with v label +\draw[arr] (sel.south) -- (motor.north) + node[midway, fill=white, font=\small, inner sep=2pt] {Voltage $v$}; + +% output +\node[font=\small] (tau) at (0, -1.8) {Torque $\tau$}; +\draw[arr] (motor.south) -- (tau.north); + +\end{tikzpicture} +\caption{Controller pipeline. The \texttt{input} attribute selects how $v$ is derived from \texttt{ctrl}; motor physics is identical downstream.} +\label{fig:controller_pipeline} +\end{figure} + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Position Mode} +\label{sec:position_mode} + +When \texttt{input="position"}, the user command $u = {}$\texttt{ctrl} is a target position, yielding voltage: +\begin{equation} + v = k_p \, (u - \theta) + k_i \, x_I - k_d \, \dot\theta + \label{eq:position_mode} +\end{equation} +where $\theta$ is the actuator length, $\dot\theta \equiv \omega$ is the actuator velocity, $x_I$ is the integral of position error, and $k_p$, $k_i$, $k_d$ are the proportional, integral, and derivative gains. + +\noindent The signs in~\eqref{eq:position_mode} follow MuJoCo convention: $k_p > 0$ drives toward the target, $k_d > 0$ provides damping (opposing velocity), and $k_i > 0$ reduces steady-state error. + +\paragraph{Integral state.} When $k_i > 0$, one additional activation state $x_I$ is allocated, governed by: +\begin{equation*} + \dot{x}_I = u - \theta + \label{eq:position_integral} +\end{equation*} +When $k_i = 0$, no integral state is added and the controller reduces to PD. + +\paragraph{Effective torque.} Substituting~\eqref{eq:position_mode} into the stateless torque equation~\eqref{eq:torque_speed}: +\begin{equation*} + \tau = \frac{K}{R} v - \frac{K^2}{R}\dot\theta + = \underbrace{\frac{K k_p}{R}}_{\text{stiffness}} (u - \theta) + + \frac{K k_i}{R} x_I + - \underbrace{\frac{K(K + k_d)}{R}}_{\text{damping}} \dot\theta + \label{eq:position_torque} +\end{equation*} +Note that the motor's back-EMF term $K^2\dot\theta/R$ contributes {\em additional damping} beyond the controller $k_d$ term. Even with $k_d\!=\!0$, the motor provides natural damping $K^2/R$. The computed $v$ is subject to voltage saturation (\S\ref{sec:voltage_saturation}). + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Symbol & Units \\ +\midrule +\atKP{} & $k_p$ & V/rad \\ +\atKI{} & $k_i$ & V/(rad$\cdot$s) \\ +\atKD{} & $k_d$ & V$\cdot$s/rad \\ +\bottomrule +\end{tabular} +\caption{Position mode controller gains.} +\label{tab:position_params} +\end{table} + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Velocity Mode} +\label{sec:velocity_mode} + +When \texttt{input="velocity"}, the user command $u = {}$\texttt{ctrl} is a target velocity, and $k_p$, $k_i$ are the proportional and integral gains. +\begin{equation} + v = k_p \, (u - \dot\theta) + k_i \, (x_I - \theta) + \label{eq:velocity_mode} +\end{equation} + +\paragraph{Integral state.} When $k_i > 0$, one additional activation state $x_I$ is allocated, governed by the integrator: +\begin{equation*} + \dot{x}_I = u + \label{eq:velocity_integral} +\end{equation*} +The term $k_i(x_I - \theta)$ then tracks a target position $x_I$ advancing at the commanded velocity $u$. This matches MuJoCo's \texttt{intvelocity} actuator behavior. + +When $k_i = 0$, no integral state is added and the controller provides pure velocity feedback. The computed $v$ is subject to voltage saturation (\S\ref{sec:voltage_saturation}). + +\paragraph{Effective torque.} Substituting~\eqref{eq:velocity_mode} into~\eqref{eq:torque_speed}: +\begin{equation*} + \tau = \underbrace{\frac{K k_i}{R}}_{\text{stiffness}} (x_I - \theta) + - \underbrace{\frac{K(K + k_p)}{R}}_{\text{damping}} \dot\theta + + \frac{K k_p}{R} u + \label{eq:velocity_torque} +\end{equation*} +Note the role swap compared to position mode: $k_i$ provides stiffness (position tracking to $x_I$) while $k_p$ adds damping alongside the motor's natural back-EMF damping $K^2/R$. + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Symbol & Units \\ +\midrule +\atKP{} & $k_p$ & V$\cdot$s/rad \\ +\atKI{} & $k_i$ & V/rad \\ +\bottomrule +\end{tabular} +\caption{Velocity mode controller gains.} +\label{tab:velocity_params} +\end{table} + +\pagebreak + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Setpoint Slew Rate} +\label{sec:setpoint_slew} + +The user command $u = {}$\texttt{ctrl} can change discontinuously between timesteps. When \atSLEW{} is set ($s > 0$), the effective setpoint is rate-limited: +\begin{equation*} + u \leftarrow \text{clip}(u, \; u_{\text{prev}} \pm s \cdot \Delta t) +\end{equation*} +where $u_{\text{prev}}$ is the previous effective setpoint and $\Delta t$ is the timestep. This smoothly ramps the reference trajectory instead of allowing instantaneous jumps. + +\paragraph{State variable.} When $s > 0$, one activation state $u_{\text{prev}}$ is allocated, and updated each step to $u$ (post-clamping). + +\paragraph{Units.} The slew rate $s$ has mode-dependent units: rad/s for position mode (limiting setpoint velocity), rad/s\textsuperscript{2} for velocity mode (limiting setpoint acceleration), and V/s for voltage mode. + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Anti-windup} +\label{sec:anti_windup} + +When $k_i > 0$, the integrator state $x_I$ provides steady-state error correction. However, sustained saturation or large setpoint changes can cause $x_I$ to grow excessively, leading to overshoot (integral windup). To prevent this, when \atIMAXINT{} is set ($I_{\max} > 0$), the state is bounded each step: +\begin{equation*} + x_I \leftarrow \text{clip}(x_I, \pm I_{\max}) +\end{equation*} +This prevents controller windup even when drive signals are saturated. + +% - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +\subsubsection{Voltage Saturation} +\label{sec:voltage_saturation} + +In \texttt{position} and \texttt{velocity} modes, the computed voltage $v$ can be arbitrarily large (proportional to the error). Real motor drivers are limited by their supply voltage. When \atVMAX{} is set ($v_{\max} > 0$), a voltage clamp is applied before the motor equations: +\begin{equation*} + v \leftarrow \text{clip}(v, \pm v_{\max}) + \label{eq:vlimit} +\end{equation*} +This differs from \texttt{ctrlrange} (clamping user command $u$) and \texttt{forcerange} (clamping output torque). In position and velocity modes, \texttt{ctrlrange} limits the setpoint while \atVMAX{} limits the drive signal. In voltage mode ($v = u$), both clamp the voltage; if both are set, the tighter limit wins. + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lll@{}} +\toprule +Attribute & Symbol & Units \\ +\midrule +\atKP{} & $k_p$ & mode-dependent \\ +\atKI{} & $k_i$ & mode-dependent \\ +\atKD{} & $k_d$ & V$\cdot$s/rad \\ +\atSLEW{} & $s$ & ctrl-units/s \\ +\atIMAXINT{} & $I_{\max}$ & mode-dependent \\ +\atVMAX{} & $v_{\max}$ & Volt \\ +\bottomrule +\end{tabular} +\caption{Controller attributes.} +\label{tab:controller_attrs} +\end{table} + +% --------------------------------------------------------------------------- +% Low-Level Semantics +% --------------------------------------------------------------------------- +\subsection{Low-Level Semantics} +\label{sec:array_semantics} + +The \texttt{dcmotor} actuator uses the enum value types \texttt{mjGAIN\_DCMOTOR}, \texttt{mjDYN\_DCMOTOR}, \texttt{mjBIAS\_DCMOTOR}, and populates the rows of several \texttt{mjModel} arrays (all \texttt{actuator\_*}), as follows: + +\begin{table}[H] +\centering +\footnotesize +\begin{tabular}{@{}llll@{}} +\toprule +Array & Index & Symbol & Description \\ +\midrule +\texttt{gainprm} & 0 & $R$ & Resistance ($\Omega$) \\ +& 1 & $K$ & Motor constant (N$\cdot$m/A) \\ +& 2 & $\alpha$ & Resistance coeff.\ ($\text{K}^{-1}$) \\ +& 3 & $T_0$ & Reference temperature (\textdegree C) \\ +& 4 & $k_p$ & Controller proportional gain \\ +& 5 & $k_i$ & Controller integral gain \\ +& 6 & $k_d$ & Controller derivative gain \\ +& 7 & $v_{\max}$ & Voltage saturation (V) \\ +& 8 & --- & Input mode (0:\ $v$, 1:\ $\theta$, 2:\ $\dot\theta$) \\ +\midrule +\texttt{dynprm} & 0 & $t_e$ & Electrical time constant (s) \\ +& 1 & $(di{/}dt)_{\max}$ & Current rate limit (A/s) \\ +& 2 & $R_T$ & Thermal resistance ($\text{K}$/W) \\ +& 3 & $C$ & Thermal capacitance (J/$\text{K}$) \\ +& 4 & $T_a$ & Ambient temperature (\textdegree C) \\ +& 5 & $\sigma_0$ & LuGre bristle stiffness \\ +& 6 & $\sigma_1$ & LuGre bristle damping \\ +& 7 & $s$ & Controller slew rate \\ +& 8 & $I_{\max}$ & Integral limit (anti-windup) \\ +\midrule +\texttt{biasprm} & 0 & $A$ & Cogging amplitude (N$\cdot$m) \\ +& 1 & $N_p$ & Cogging periodicity \\ +& 2 & $\phi$ & Cogging phase (rad) \\ +& 3 & $\tau_c$ & LuGre Coulomb fric. (N$\cdot$m) \\ +& 4 & $\tau_s$ & LuGre static fric. (N$\cdot$m) \\ +& 5 & $\omega_s$ & Stribeck velocity (rad/s) \\ +\midrule +\texttt{forcerange} & 0 & $-\tau_{\max}$ & Minimum torque (N$\cdot$m) \\ +& 1 & $\tau_{\max}$ & Maximum torque (N$\cdot$m) \\ +\midrule +\texttt{damping} & 0 & $B_1 (+\sigma_2)$ & Linear (+ LuGre viscous) \\ +& 1 & $B_2$ & Quadratic \\ +& 2 & $B_3$ & Cubic \\ +\midrule +\texttt{armature} & 0 & $J_r$ & Actuator armature \\ +\midrule +\texttt{gear} & 0 & $N$ & Gear ratio \\ +\bottomrule +\end{tabular} +\caption{\texttt{mjModel} array semantics for the \texttt{dcmotor} actuator.} +\label{tab:array_semantics} +\end{table} + +\paragraph{Runtime mutability.} +Most \texttt{mjModel} parameters listed above may be freely modified at runtime for system identification or gain tuning. However, five parameters control the \emph{number} of activation states, which is determined at compile time and cannot change during simulation. Toggling any of the following parameters between zero and positive after compilation is an error: +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}llcl@{}} +\toprule +Parameter & Storage & State & Semantics \\ +\midrule +$s$ & \texttt{dynprm[7]} & $u_{\text{prev}}$ & previous control \\ +$k_i$ & \texttt{gainprm[5]} & $x_I$ & controller integral \\ +$R_T, C$ & \texttt{dynprm[2,3]} & $T$ & temperature rise \\ +$\sigma_0$ & \texttt{dynprm[5]} & $z$ & bristle deflection \\ +$t_e$ & \texttt{dynprm[0]} & $i$ & winding current \\ +\bottomrule +\end{tabular} +\caption{Compile-time state switches, allocated in \texttt{act} in the order shown. Do not toggle between zero and positive at runtime.} +\end{table} + + +\onecolumn +%============================================================================= +% DATASHEET MAPPING +%============================================================================= +\section{Datasheet Mapping} +\label{sec:datasheet} + +Table~\ref{tab:datasheet} maps commercial motor datasheet specifications to attributes. The left column shows the datasheet entry as typically labeled by motor manufacturers; the right column shows the corresponding \texttt{dcmotor} MJCF attribute, when one exists. Derived quantities that are not direct attributes (gradient, mechanical time constant) are included for completeness. + +\begin{table}[H] +\centering +\small +\begin{tabular}{@{}lllll@{}} +\toprule +Specification & Symbol & Formula / Note & Datasheet Symbol & Attribute \\ +\midrule +Resistance & $R$ & Terminal resistance & R, Ra & \atR{} \\ +Torque Constant & $K_t$ & $\tau = K_t i$ & kt, km & \atKt{} \\ +Back-EMF Constant & $K_e$ & $v_{\text{back}} = K_e \omega$ & ke & \atKe{} \\ +Speed Constant & $K_v$ & $K_v = 1/K_e$ & kn, kv & $1/$\atKe{} \\ +Nominal Voltage & $v_n$ & Rated voltage (e.g., 24V) & Un, VDC & \atVM{} \\ +No-load Speed & $\omega_0$ & $\omega_0 \approx v_n / K$ & n0 & \atNLS{} \\ +Stall Torque & $\tau_0$ & $\tau_0 = K v_n / R$ & MH, Ts & \atSTALL{} \\ +Stall Current & $i_s$ & Max.\ possible: $i_s = v_n / R$ & IA & \\ +Nominal Current & $i_{\max}$ & Thermal limit (continuous) & Ic, IN & \atIMAX{} \\ +Peak Current & $i_{\text{peak}}$ & Short-term limit & Ipk & \\ +\midrule +Coulomb Friction & $\tau_c$ & Dry friction opposing motion & $T_f$ & \texttt{frictionloss} (joint)\\ +Viscous Friction & $B$ & Drag $\propto \omega$ & $C_v$ & \texttt{damping} \\ +Rotor Inertia & $J_r$ & Reflected: $J_{\text{eff}} = J_r N^2$ & J, Jm & \texttt{armature} \\ +Gear Ratio & $N$ & Reduction ratio & $N$, $i$ & \texttt{gear} \\ +Gearbox Efficiency & $\eta$ & Fold into $K$: use $\eta K$ & $\eta$ & \\ +Cogging Amplitude & $A$ & Peak cogging torque & --- & \atCOGA{} \\ +Cogging Periodicity & $N_p$ & Poles $\times$ slots/pole & --- & \atCOGP{} \\ +\midrule +Nominal Torque & $\tau_n$ & $\tau_n = K \cdot i_{\max}$ & $M_N$, $T_c$ & \\ +No-load Current & $i_0$ & Friction: Eq.~\eqref{eq:noload} & $I_0$ & \\ +Gradient & $\partial\omega/\partial\tau$ & $-R / K^2$ & $\Delta n / \Delta M$ & \\ +Mech.\ Time Const. & $t_m$ & $t_m = R\,J / K^2$ & $\tau_m$ & \\ +\midrule +Inductance & $L$ & Terminal inductance & L & \atL{} \\ +Elec.\ Time Const. & $t_e$ & $t_e = L / R$ & $\tau_e$ & \atTE{} \\ +\midrule +Thermal Resistance & $R_T$ & Winding-to-ambient & Rth & \atRT{} \\ +Thermal Capacitance & $C$ & $C = t_T / R_T$ & $C_{\text{th}}$ & \atTC{} \\ +Thermal Time Const. & $t_T$ & $t_T = R_T C$ & $\tau_{\text{th}}$ & \atTT{} \\ +Ref.\ Temperature & $T_0$ & Temperature at which $R$ is specified & $T_{\text{ref}}$ & \atTREF{} \\ +Ambient Temperature & $T_a$ & Operating environment & --- & \atTAMB{} \\ +Max.\ Winding Temp. & $T_{\max}$ & Absolute limit & $T_{\max}$ & \\ +Res.\ Temp.\ Coeff. & $\alpha$ & $\approx 0.0039\, \text{K}^{-1}$ (copper) & $\alpha_{\text{Cu}}$ & \atALPHA{} \\ +\bottomrule +\end{tabular} +\caption{Datasheet parameters and their relation to model constants. Groups: electrical, mechanical, derived, inductance, thermal.} +\label{tab:datasheet} +\end{table} + +\small +\bibliographystyle{ieeetr} +\bibliography{refs} + +\end{document} diff --git a/doc/dcmotor/refs.bib b/doc/dcmotor/refs.bib new file mode 100644 index 0000000000..29708bc7d3 --- /dev/null +++ b/doc/dcmotor/refs.bib @@ -0,0 +1,82 @@ +@article{dewit95, + author = {Canudas de Wit, C. and Olsson, H. and {\AA}str{\"o}m, K. J. and Lischinsky, P.}, + title = {{A New Model for Control of Systems with Friction}}, + journal = {IEEE Transactions on Automatic Control}, + volume = {40}, + number = {3}, + pages = {419--425}, + year = {1995}, + month = mar, +} + +@article{lugre_revisited, + author = {{\AA}str{\"o}m, K. J. and Canudas de Wit, C.}, + title = {{Revisiting the LuGre Friction Model}}, + journal = {IEEE Control Systems Magazine}, + volume = {28}, + number = {6}, + pages = {101--114}, + year = {2008}, + month = dec, +} + +@techreport{dahl68, + author = {Dahl, P.}, + title = {{A Solid Friction Model}}, + institution = {The Aerospace Corporation}, + address = {El Segundo, CA}, + number = {TOR-0158(3107-18)-1}, + year = {1968}, +} + +@book{hughes2019, + author = {Hughes, Austin and Drury, Bill}, + title = {{Electric Motors and Drives: Fundamentals, Types and Applications}}, + edition = {5th}, + publisher = {Newnes}, + year = {2019}, +} + +@book{tedrake2024, + author = {Tedrake, Russ}, + title = {{Underactuated Robotics: Algorithms for Walking, Running, + Swimming, Flying, and Manipulation}}, + publisher = {MIT}, + year = {2024}, + note = {Course notes for MIT 6.832, \url{https://underactuated.mit.edu}}, +} + +@article{isaaclab2025, + author = {Mittal, Mayank and Yu, Calvin and Yu, Qinxi and Liu, Jingzhou + and Rudin, Nikita and Hoeller, David and Yuan, Jia Lin + and Singh, Ritvik and Guo, Yunrong and Mazhar, Hammad + and Mandlekar, Ajay and Babich, Buck and State, Gavriel + and Hutter, Marco and Garg, Animesh}, + title = {{Isaac Lab: A Unified and Modular Framework for Robot Learning}}, + journal = {arXiv preprint arXiv:2502.11048}, + year = {2025}, +} + +@article{stribeck1902, + author = {Stribeck, R.}, + title = {{Die wesentlichen Eigenschaften der Gleit- und Rollenlager}}, + journal = {Zeitschrift des Vereines Deutscher Ingenieure}, + volume = {46}, + pages = {1341--1348, 1432--1438, 1463--1470}, + year = {1902}, +} + +@misc{maxon_formulas, + author = {{Maxon Motor AG}}, + title = {{Key Information on Maxon DC Motors and Maxon EC Motors}}, + howpublished = {\url{https://www.maxongroup.com}}, + year = {2024}, + note = {{Maxon} Academy Technical Notes}, +} + +@misc{simscape_dcmotor, + author = {{MathWorks}}, + title = {{DC Motor --- Simscape Electrical Block Reference}}, + howpublished = {\url{https://www.mathworks.com/help/sps/ref/dcmotor.html}}, + year = {2024}, +} diff --git a/doc/includes/references.h b/doc/includes/references.h index bc8a548d52..fd855a1c1f 100644 --- a/doc/includes/references.h +++ b/doc/includes/references.h @@ -635,19 +635,22 @@ typedef enum mjtDyn_ { // type of actuator dynamics mjDYN_INTEGRATOR, // integrator: da/dt = u mjDYN_FILTER, // linear filter: da/dt = (u-a) / tau mjDYN_FILTEREXACT, // linear filter: da/dt = (u-a) / tau, with exact integration - mjDYN_MUSCLE, // piece-wise linear filter with two time constants + mjDYN_MUSCLE, // piecewise linear filter with two time constants + mjDYN_DCMOTOR, // DC motor electrical dynamics mjDYN_USER // user-defined dynamics type } mjtDyn; typedef enum mjtGain_ { // type of actuator gain mjGAIN_FIXED = 0, // fixed gain mjGAIN_AFFINE, // const + kp*length + kv*velocity mjGAIN_MUSCLE, // muscle FLV curve computed by mju_muscleGain() + mjGAIN_DCMOTOR, // DC motor gain: K or K/R mjGAIN_USER // user-defined gain type } mjtGain; typedef enum mjtBias_ { // type of actuator bias mjBIAS_NONE = 0, // no bias mjBIAS_AFFINE, // const + kp*length + kv*velocity mjBIAS_MUSCLE, // muscle passive force computed by mju_muscleBias() + mjBIAS_DCMOTOR, // DC motor bias: back-EMF, cogging, LuGre friction mjBIAS_USER // user-defined bias type } mjtBias; typedef enum mjtObj_ { // type of MujoCo object @@ -3659,6 +3662,10 @@ const char* mjs_setToMuscle(mjsActuator* actuator, double timeconst[2], double t double range[2], double force, double scale, double lmin, double lmax, double vmax, double fpmax, double fvmax); const char* mjs_setToAdhesion(mjsActuator* actuator, double gain); +const char* mjs_setToDCMotor(mjsActuator* actuator, double motorconst[2], double resistance, + double nominal[3], double saturation[4], double inductance[2], + double cogging[3], double controller[5], double thermal[6], + double lugre[6], int input_mode); mjsMesh* mjs_addMesh(mjSpec* s, const mjsDefault* def); mjsHField* mjs_addHField(mjSpec* s); mjsSkin* mjs_addSkin(mjSpec* s); diff --git a/include/mujoco/mjmodel.h b/include/mujoco/mjmodel.h index 2f9cdd35db..49cfee0bd4 100644 --- a/include/mujoco/mjmodel.h +++ b/include/mujoco/mjmodel.h @@ -244,7 +244,8 @@ typedef enum mjtDyn_ { // type of actuator dynamics mjDYN_INTEGRATOR, // integrator: da/dt = u mjDYN_FILTER, // linear filter: da/dt = (u-a) / tau mjDYN_FILTEREXACT, // linear filter: da/dt = (u-a) / tau, with exact integration - mjDYN_MUSCLE, // piece-wise linear filter with two time constants + mjDYN_MUSCLE, // piecewise linear filter with two time constants + mjDYN_DCMOTOR, // DC motor electrical dynamics mjDYN_USER // user-defined dynamics type } mjtDyn; @@ -253,6 +254,7 @@ typedef enum mjtGain_ { // type of actuator gain mjGAIN_FIXED = 0, // fixed gain mjGAIN_AFFINE, // const + kp*length + kv*velocity mjGAIN_MUSCLE, // muscle FLV curve computed by mju_muscleGain() + mjGAIN_DCMOTOR, // DC motor gain: K or K/R mjGAIN_USER // user-defined gain type } mjtGain; @@ -261,6 +263,7 @@ typedef enum mjtBias_ { // type of actuator bias mjBIAS_NONE = 0, // no bias mjBIAS_AFFINE, // const + kp*length + kv*velocity mjBIAS_MUSCLE, // muscle passive force computed by mju_muscleBias() + mjBIAS_DCMOTOR, // DC motor bias: back-EMF, cogging, LuGre friction mjBIAS_USER // user-defined bias type } mjtBias; diff --git a/include/mujoco/mjplugin.h b/include/mujoco/mjplugin.h index 7c6300b9b9..09f3361c53 100644 --- a/include/mujoco/mjplugin.h +++ b/include/mujoco/mjplugin.h @@ -183,38 +183,43 @@ struct mjSDF_ { typedef struct mjSDF_ mjSDF; #if defined(__has_attribute) - #if __has_attribute(constructor) - #define mjPLUGIN_LIB_INIT __attribute__((constructor)) static void _mjplugin_init(void) - #endif // __has_attribute(constructor) - -#elif defined(_MSC_VER) - - #ifndef mjDLLMAIN - #define mjDLLMAIN DllMain + #define mjPLUGIN_LIB_INIT(n) \ + static void _mj_init_##n(void) __attribute__((constructor)); \ + static void _mj_init_##n(void) #endif - - #if !defined(mjEXTERNC) - #if defined(__cplusplus) - #define mjEXTERNC extern "C" +#elif defined(_MSC_VER) + // on x86, symbols are decorated with a leading underscore + #ifdef _M_IX86 + #define LINKER_NAME "__mj_ptr_" #else - #define mjEXTERNC - #endif // defined(__cplusplus) - #endif // !defined(mjEXTERNC) - - // NOLINTBEGIN(runtime/int) - #define mjPLUGIN_LIB_INIT \ - static void _mjplugin_dllmain(void); \ - mjEXTERNC int __stdcall mjDLLMAIN(void* hinst, unsigned long reason, void* reserved) { \ - if (reason == 1) { \ - _mjplugin_dllmain(); \ - } \ - return 1; \ - } \ - static void _mjplugin_dllmain(void) - // NOLINTEND(runtime/int) - -#endif // defined(_MSC_VER) + #define LINKER_NAME "_mj_ptr_" + #endif + + #pragma section(".CRT$XCU", read) + + #if !defined(mjEXTERNC) + #if defined(__cplusplus) + #define mjEXTERNC extern "C" + #else + #define mjEXTERNC + #endif // defined(__cplusplus) + #endif // !defined(mjEXTERNC) + + #define mjPLUGIN_LIB_INIT(n) \ + static void __cdecl _mj_init_##n(void); \ + /* use mjEXTERNC to prevent C++ name mangling */ \ + /* allocate the function pointer to the .CRT$XCU section of the executable */ \ + /* functions in this section are executed on startup before calling main() */ \ + mjEXTERNC __declspec(allocate(".CRT$XCU")) \ + void (__cdecl * _mj_ptr_##n)(void) = _mj_init_##n; \ + /* Force the linker to include the pointer symbol */ \ + __pragma(comment(linker, "/include:" LINKER_NAME #n)) \ + static void __cdecl _mj_init_##n(void) + +#else + #error "Unknown compiler: Plugin registration not supported." +#endif // function pointer type for mj_loadAllPluginLibraries callback typedef void (*mjfPluginLibraryLoadCallback)(const char* filename, int first, int count); diff --git a/include/mujoco/mujoco.h b/include/mujoco/mujoco.h index 362f16e9b6..ba4963606e 100644 --- a/include/mujoco/mujoco.h +++ b/include/mujoco/mujoco.h @@ -1726,6 +1726,12 @@ MJAPI const char* mjs_setToMuscle(mjsActuator* actuator, double timeconst[2], do // Set actuator to active adhesion; return error if any. MJAPI const char* mjs_setToAdhesion(mjsActuator* actuator, double gain); +// Set actuator to DC motor; return error if any. +MJAPI const char* mjs_setToDCMotor(mjsActuator* actuator, double motorconst[2], double resistance, + double nominal[3], double saturation[4], double inductance[2], + double cogging[3], double controller[5], double thermal[6], + double lugre[6], int input_mode); + //---------------------------------- Assets -------------------------------------------------------- diff --git a/mjx/cuda_requirements.txt b/mjx/cuda_requirements.txt index 0f9ae0d572..2ced5d1285 100644 --- a/mjx/cuda_requirements.txt +++ b/mjx/cuda_requirements.txt @@ -16,8 +16,8 @@ jax-cuda12-pjrt==0.5.3; python_version >= '3.10' \ jax-cuda12-pjrt==0.4.30; python_version == '3.9' \ --hash=sha256:895d0198ad99638fcaf976c47592e2a543eef79ea15fabd24a402d055390c328 \ --hash=sha256:c36fb1e0c236563bf3a87e70f4d1ab28a31d7cf5d722c9ede30c4172116e8bcb -warp-lang==1.11.1 \ - --hash=sha256:1ad11f1fa775269e991a3d55039152c8a504baf86701c849b485cb8e66c49d15 \ - --hash=sha256:8b098f41e71d421d80ee7562e38aa8380ff6b0d3b4c6ee866cfbdef733ac5bdc \ - --hash=sha256:5d0904b0eefcc81f39ba65375427a3de99006088aa43e24a9011263f07d0cd07 \ - --hash=sha256:15dc10aa51fb0fdbe1ca16d52e5fadca35a47ffd9d0c636826506f96bb2e7c41 +warp-lang==1.12.0 \ + --hash=sha256:c78c3701d5cad86c30ef5017410d294ec46a396bb0d502ee1c98743494f3a62f \ + --hash=sha256:a1436f60a1881cd94f787e751a83fc0987626be2d3e2b4e74c64a6947c6d1266 \ + --hash=sha256:a2d6decba693aba5b828573c4414fd6a3f4c4a934db9c322736ef2b3fa99fe76 \ + --hash=sha256:697248edd2f1e2952f50e3db33b214af76173641a8894aacc467bed6dc247f8a diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/__init__.py b/mjx/mujoco/mjx/third_party/mujoco_warp/__init__.py index 1ff05ff6a6..40653b62dc 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/__init__.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/__init__.py @@ -64,6 +64,7 @@ from mujoco.mjx.third_party.mujoco_warp._src.render import render as render from mujoco.mjx.third_party.mujoco_warp._src.render_util import get_depth as get_depth from mujoco.mjx.third_party.mujoco_warp._src.render_util import get_rgb as get_rgb +from mujoco.mjx.third_party.mujoco_warp._src.render_util import get_segmentation as get_segmentation from mujoco.mjx.third_party.mujoco_warp._src.sensor import energy_pos as energy_pos from mujoco.mjx.third_party.mujoco_warp._src.sensor import energy_vel as energy_vel from mujoco.mjx.third_party.mujoco_warp._src.sensor import sensor_acc as sensor_acc @@ -92,6 +93,7 @@ from mujoco.mjx.third_party.mujoco_warp._src.types import BiasType as BiasType from mujoco.mjx.third_party.mujoco_warp._src.types import BroadphaseFilter as BroadphaseFilter from mujoco.mjx.third_party.mujoco_warp._src.types import BroadphaseType as BroadphaseType +from mujoco.mjx.third_party.mujoco_warp._src.types import Callback as Callback from mujoco.mjx.third_party.mujoco_warp._src.types import ConeType as ConeType from mujoco.mjx.third_party.mujoco_warp._src.types import Constraint as Constraint from mujoco.mjx.third_party.mujoco_warp._src.types import Contact as Contact diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/bvh.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/bvh.py index c40fcfc98f..58f7b2213a 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/bvh.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/bvh.py @@ -189,12 +189,12 @@ def _compute_bvh_bounds( upper_out: wp.array(dtype=wp.vec3), group_out: wp.array(dtype=int), ): - world_id, geom_local_id = wp.tid() + worldid, geom_local_id = wp.tid() geom_id = enabled_geom_ids[geom_local_id] - pos = geom_xpos_in[world_id, geom_id] - rot = geom_xmat_in[world_id, geom_id] - size = geom_size[world_id % geom_size.shape[0], geom_id] + pos = geom_xpos_in[worldid, geom_id] + rot = geom_xmat_in[worldid, geom_id] + size = geom_size[worldid % geom_size.shape[0], geom_id] type = geom_type[geom_id] # TODO: Investigate branch elimination with static loop unrolling @@ -218,9 +218,9 @@ def _compute_bvh_bounds( hfield_center = pos + rot[:, 2] * size[2] lower_bound, upper_bound = _compute_box_bounds(hfield_center, rot, size) - lower_out[world_id * bvh_ngeom + geom_local_id] = lower_bound - upper_out[world_id * bvh_ngeom + geom_local_id] = upper_bound - group_out[world_id * bvh_ngeom + geom_local_id] = world_id + lower_out[worldid * bvh_ngeom + geom_local_id] = lower_bound + upper_out[worldid * bvh_ngeom + geom_local_id] = upper_bound + group_out[worldid * bvh_ngeom + geom_local_id] = worldid @wp.kernel @@ -235,14 +235,70 @@ def compute_bvh_group_roots( group_root_out[tid] = root +@wp.kernel +def _compute_flex_bvh_bounds( + # Model: + flex_vertadr: wp.array(dtype=int), + flex_vertnum: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), + # Data in: + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), + # In: + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + bvh_ngeom: int, + total_bvh_size: int, + # Out: + lower_out: wp.array(dtype=wp.vec3), + upper_out: wp.array(dtype=wp.vec3), + group_out: wp.array(dtype=int), +): + worldid, flexlocalid = wp.tid() + + flex_id = flex_geom_flexid[flexlocalid] + edge_id = flex_geom_edgeid[flexlocalid] + out_idx = worldid * total_bvh_size + bvh_ngeom + flexlocalid + radius = flex_radius[flex_id] + inflate = wp.vec3(radius, radius, radius) + + if edge_id >= 0: # capsule (1D edge) + edge = flex_edge[edge_id] + vert_adr = flex_vertadr[flex_id] + v0 = flexvert_xpos_in[worldid, vert_adr + edge[0]] + v1 = flexvert_xpos_in[worldid, vert_adr + edge[1]] + lower_out[out_idx] = wp.min(v0, v1) - inflate + upper_out[out_idx] = wp.max(v0, v1) + inflate + else: # mesh (2D/3D) + vert_adr = flex_vertadr[flex_id] + nvert = flex_vertnum[flex_id] + min_bound = wp.vec3(MJ_MAXVAL, MJ_MAXVAL, MJ_MAXVAL) + max_bound = wp.vec3(-MJ_MAXVAL, -MJ_MAXVAL, -MJ_MAXVAL) + for i in range(nvert): + v = flexvert_xpos_in[worldid, vert_adr + i] + min_bound = wp.min(min_bound, v) + max_bound = wp.max(max_bound, v) + lower_out[out_idx] = min_bound - inflate + upper_out[out_idx] = max_bound + inflate + + group_out[out_idx] = worldid + + def build_scene_bvh(mjm: mujoco.MjModel, mjd: mujoco.MjData, rc: RenderContext, nworld: int): """Build a global BVH for all geometries in all worlds.""" + total_bvh_size = rc.bvh_ngeom + rc.bvh_nflexgeom + geom_type = wp.array(mjm.geom_type, dtype=int) geom_dataid = wp.array(mjm.geom_dataid, dtype=int) geom_size = wp.array(np.tile(mjm.geom_size[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3) geom_xpos = wp.array(np.tile(mjd.geom_xpos[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3) geom_xmat = wp.array(np.tile(mjd.geom_xmat.reshape(mjm.ngeom, 3, 3)[np.newaxis, :, :, :], (nworld, 1, 1, 1)), dtype=wp.mat33) + flex_vertadr = wp.array(mjm.flex_vertadr, dtype=int) + flex_vertnum = wp.array(mjm.flex_vertnum, dtype=int) + flex_edge = wp.array(mjm.flex_edge, dtype=wp.vec2i) + flex_radius = wp.array(mjm.flex_radius, dtype=float) + wp.launch( kernel=_compute_bvh_bounds, dim=(nworld, rc.bvh_ngeom), @@ -252,7 +308,7 @@ def build_scene_bvh(mjm: mujoco.MjModel, mjd: mujoco.MjData, rc: RenderContext, geom_size, geom_xpos, geom_xmat, - rc.bvh_ngeom, + total_bvh_size, rc.enabled_geom_ids, rc.mesh_bounds_size, rc.hfield_bounds_size, @@ -262,6 +318,26 @@ def build_scene_bvh(mjm: mujoco.MjModel, mjd: mujoco.MjData, rc: RenderContext, ], ) + flexvert_xpos = wp.array(np.tile(mjd.flexvert_xpos[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3) + wp.launch( + kernel=_compute_flex_bvh_bounds, + dim=(nworld, rc.bvh_nflexgeom), + inputs=[ + flex_vertadr, + flex_vertnum, + flex_edge, + flex_radius, + flexvert_xpos, + rc.flex_geom_flexid, + rc.flex_geom_edgeid, + rc.bvh_ngeom, + total_bvh_size, + rc.lower, + rc.upper, + rc.group, + ], + ) + bvh = wp.Bvh(rc.lower, rc.upper, groups=rc.group, constructor="sah") # BVH handle must be stored to avoid garbage collection @@ -277,6 +353,8 @@ def build_scene_bvh(mjm: mujoco.MjModel, mjd: mujoco.MjData, rc: RenderContext, def refit_scene_bvh(m: Model, d: Data, rc: RenderContext): + total_bvh_size = rc.bvh_ngeom + rc.bvh_nflexgeom + wp.launch( kernel=_compute_bvh_bounds, dim=(d.nworld, rc.bvh_ngeom), @@ -286,7 +364,7 @@ def refit_scene_bvh(m: Model, d: Data, rc: RenderContext): m.geom_size, d.geom_xpos, d.geom_xmat, - rc.bvh_ngeom, + total_bvh_size, rc.enabled_geom_ids, rc.mesh_bounds_size, rc.hfield_bounds_size, @@ -296,6 +374,26 @@ def refit_scene_bvh(m: Model, d: Data, rc: RenderContext): ], ) + if rc.bvh_nflexgeom > 0: + wp.launch( + kernel=_compute_flex_bvh_bounds, + dim=(d.nworld, rc.bvh_nflexgeom), + inputs=[ + m.flex_vertadr, + m.flex_vertnum, + m.flex_edge, + m.flex_radius, + d.flexvert_xpos, + rc.flex_geom_flexid, + rc.flex_geom_edgeid, + rc.bvh_ngeom, + total_bvh_size, + rc.lower, + rc.upper, + rc.group, + ], + ) + rc.bvh.refit() @@ -500,6 +598,12 @@ def build_hfield_bvh( @wp.kernel def accumulate_flex_vertex_normals( # Model: + nflex: int, + flex_dim: wp.array(dtype=int), + flex_vertadr: wp.array(dtype=int), + flex_elemadr: wp.array(dtype=int), + flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elem: wp.array(dtype=int), # Data in: flexvert_xpos_in: wp.array2d(dtype=wp.vec3), @@ -509,10 +613,22 @@ def accumulate_flex_vertex_normals( """Accumulate per-vertex normals by summing adjacent face normals.""" worldid, elemid = wp.tid() - elem_base = elemid * 3 - i0 = flex_elem[elem_base + 0] - i1 = flex_elem[elem_base + 1] - i2 = flex_elem[elem_base + 2] + for i in range(nflex): + locid = elemid - flex_elemadr[i] + if locid >= 0 and locid < flex_elemnum[i]: + f = i + break + + if flex_dim[f] == 1 or flex_dim[f] == 3: + return + + local_elemid = elemid - flex_elemadr[f] + elem_adr = flex_elemdataadr[f] + vert_adr = flex_vertadr[f] + elem_base = elem_adr + local_elemid * 3 + i0 = vert_adr + flex_elem[elem_base + 0] + i1 = vert_adr + flex_elem[elem_base + 1] + i2 = vert_adr + flex_elem[elem_base + 2] v0 = flexvert_xpos_in[worldid, i0] v1 = flexvert_xpos_in[worldid, i1] @@ -611,11 +727,12 @@ def _build_flex_2d_elements( @wp.kernel def _build_flex_2d_sides( + # Model: + flex_shell: wp.array(dtype=int), # Data in: flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: flexvert_norm_in: wp.array2d(dtype=wp.vec3), - flex_shell_in: wp.array(dtype=int), shell_adr: int, vert_adr: int, face_offset: int, @@ -635,8 +752,8 @@ def _build_flex_2d_sides( worldid, shellid = wp.tid() base = shell_adr + 2 * shellid - i0 = vert_adr + flex_shell_in[base + 0] - i1 = vert_adr + flex_shell_in[base + 1] + i0 = vert_adr + flex_shell[base + 0] + i1 = vert_adr + flex_shell[base + 1] v0 = flexvert_xpos_in[worldid, i0] v1 = flexvert_xpos_in[worldid, i1] @@ -672,10 +789,11 @@ def _build_flex_2d_sides( @wp.kernel def _build_flex_3d_shells( + # Model: + flex_shell: wp.array(dtype=int), # Data in: flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: - flex_shell_in: wp.array(dtype=int), shell_adr: int, vert_adr: int, face_offset: int, @@ -693,9 +811,9 @@ def _build_flex_3d_shells( worldid, shellid = wp.tid() base = shell_adr + shellid * 3 - i0 = vert_adr + flex_shell_in[base + 0] - i1 = vert_adr + flex_shell_in[base + 1] - i2 = vert_adr + flex_shell_in[base + 2] + i0 = vert_adr + flex_shell[base + 0] + i1 = vert_adr + flex_shell[base + 1] + i2 = vert_adr + flex_shell[base + 2] face_id = worldid * nface + face_offset + shellid base = face_id * 3 @@ -716,163 +834,163 @@ def _build_flex_3d_shells( @wp.kernel -def _update_flex_face_points( +def _update_flex_2d_face_points( # Model: - nflex: int, - flex_dim: wp.array(dtype=int), flex_vertadr: wp.array(dtype=int), flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), + flex_shelldataadr: wp.array(dtype=int), flex_elem: wp.array(dtype=int), + flex_shell: wp.array(dtype=int), + flex_radius: wp.array(dtype=float), # Data in: flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: - flex_shell_in: wp.array(dtype=int), flexvert_norm_in: wp.array2d(dtype=wp.vec3), - flex_elemdataadr: wp.array(dtype=int), - flex_shelldataadr: wp.array(dtype=int), - flex_faceadr: wp.array(dtype=int), - flex_radius: wp.array(dtype=float), - flex_workadr: wp.array(dtype=int), - flex_worknum: wp.array(dtype=int), - nfaces: int, + flex_id: int, + nface: int, smooth: bool, # Out: face_point_out: wp.array(dtype=wp.vec3), ): worldid, workid = wp.tid() - # identify which flex this work item belongs to - f = int(0) - locid = int(0) - for i in range(nflex): - locid = workid - flex_workadr[i] - if locid >= 0 and locid < flex_worknum[i]: - f = i - break - - dim = flex_dim[f] - face_offset = flex_faceadr[f] - world_face_offset = worldid * nfaces - vert_adr = flex_vertadr[f] + elem_adr = flex_elemdataadr[flex_id] + vert_adr = flex_vertadr[flex_id] + radius = flex_radius[flex_id] + nelem = flex_elemnum[flex_id] + world_face_offset = worldid * nface - if dim == 2: - radius = flex_radius[f] - elem_count = flex_elemnum[f] - - if locid < elem_count: - # 2D element faces - elemid = locid - elem_adr = flex_elemdataadr[f] - ebase = elem_adr + elemid * 3 - i0 = vert_adr + flex_elem[ebase + 0] - i1 = vert_adr + flex_elem[ebase + 1] - i2 = vert_adr + flex_elem[ebase + 2] - - v0 = flexvert_xpos_in[worldid, i0] - v1 = flexvert_xpos_in[worldid, i1] - v2 = flexvert_xpos_in[worldid, i2] - - # TODO: Use static conditional - if smooth: - n0 = flexvert_norm_in[worldid, i0] - n1 = flexvert_norm_in[worldid, i1] - n2 = flexvert_norm_in[worldid, i2] - else: - face_nrm = wp.cross(v1 - v0, v2 - v0) - face_nrm = wp.normalize(face_nrm) - n0 = face_nrm - n1 = face_nrm - n2 = face_nrm - - p0_pos = v0 + radius * n0 - p1_pos = v1 + radius * n1 - p2_pos = v2 + radius * n2 - - p0_neg = v0 - radius * n0 - p1_neg = v1 - radius * n1 - p2_neg = v2 - radius * n2 - - face_id0 = world_face_offset + face_offset + (2 * elemid) - base0 = face_id0 * 3 - face_point_out[base0 + 0] = p0_pos - face_point_out[base0 + 1] = p1_pos - face_point_out[base0 + 2] = p2_pos - - face_id1 = world_face_offset + face_offset + (2 * elemid + 1) - base1 = face_id1 * 3 - face_point_out[base1 + 0] = p0_neg - face_point_out[base1 + 1] = p1_neg - face_point_out[base1 + 2] = p2_neg - else: - # 2D shell faces - shellid = locid - elem_count - shell_adr = flex_shelldataadr[f] - sbase = shell_adr + 2 * shellid - i0 = vert_adr + flex_shell_in[sbase + 0] - i1 = vert_adr + flex_shell_in[sbase + 1] + if workid < nelem: + # 2D element faces + elemid = workid + ebase = elem_adr + elemid * 3 + i0 = vert_adr + flex_elem[ebase + 0] + i1 = vert_adr + flex_elem[ebase + 1] + i2 = vert_adr + flex_elem[ebase + 2] - v0 = flexvert_xpos_in[worldid, i0] - v1 = flexvert_xpos_in[worldid, i1] + v0 = flexvert_xpos_in[worldid, i0] + v1 = flexvert_xpos_in[worldid, i1] + v2 = flexvert_xpos_in[worldid, i2] + # TODO: Use static conditional + if smooth: n0 = flexvert_norm_in[worldid, i0] n1 = flexvert_norm_in[worldid, i1] - - shell_face_offset = face_offset + (2 * elem_count) - face_id0 = world_face_offset + shell_face_offset + (2 * shellid) - base0 = face_id0 * 3 - face_point_out[base0 + 0] = v0 + radius * n0 - face_point_out[base0 + 1] = v1 - radius * n1 - face_point_out[base0 + 2] = v1 + radius * n1 - - face_id1 = world_face_offset + shell_face_offset + (2 * shellid + 1) - base1 = face_id1 * 3 - face_point_out[base1 + 0] = v1 - radius * n1 - face_point_out[base1 + 1] = v0 + radius * n0 - face_point_out[base1 + 2] = v0 - radius * n0 + n2 = flexvert_norm_in[worldid, i2] + else: + face_nrm = wp.cross(v1 - v0, v2 - v0) + face_nrm = wp.normalize(face_nrm) + n0 = face_nrm + n1 = face_nrm + n2 = face_nrm + + p0_pos = v0 + radius * n0 + p1_pos = v1 + radius * n1 + p2_pos = v2 + radius * n2 + + p0_neg = v0 - radius * n0 + p1_neg = v1 - radius * n1 + p2_neg = v2 - radius * n2 + + face_id0 = world_face_offset + (2 * elemid) + base0 = face_id0 * 3 + face_point_out[base0 + 0] = p0_pos + face_point_out[base0 + 1] = p1_pos + face_point_out[base0 + 2] = p2_pos + + face_id1 = world_face_offset + (2 * elemid + 1) + base1 = face_id1 * 3 + face_point_out[base1 + 0] = p0_neg + face_point_out[base1 + 1] = p1_neg + face_point_out[base1 + 2] = p2_neg else: - # 3D shell faces - shellid = locid - shell_adr = flex_shelldataadr[f] - sbase = shell_adr + shellid * 3 - i0 = vert_adr + flex_shell_in[sbase + 0] - i1 = vert_adr + flex_shell_in[sbase + 1] - i2 = vert_adr + flex_shell_in[sbase + 2] + # 2D shell faces + shell_adr = flex_shelldataadr[flex_id] + shellid = workid - nelem + sbase = shell_adr + 2 * shellid + i0 = vert_adr + flex_shell[sbase + 0] + i1 = vert_adr + flex_shell[sbase + 1] v0 = flexvert_xpos_in[worldid, i0] v1 = flexvert_xpos_in[worldid, i1] - v2 = flexvert_xpos_in[worldid, i2] - face_id = world_face_offset + face_offset + shellid - fbase = face_id * 3 + n0 = flexvert_norm_in[worldid, i0] + n1 = flexvert_norm_in[worldid, i1] - face_point_out[fbase + 0] = v0 - face_point_out[fbase + 1] = v1 - face_point_out[fbase + 2] = v2 + shell_face_offset = 2 * nelem + face_id0 = world_face_offset + shell_face_offset + (2 * shellid) + base0 = face_id0 * 3 + face_point_out[base0 + 0] = v0 + radius * n0 + face_point_out[base0 + 1] = v1 - radius * n1 + face_point_out[base0 + 2] = v1 + radius * n1 + face_id1 = world_face_offset + shell_face_offset + (2 * shellid + 1) + base1 = face_id1 * 3 + face_point_out[base1 + 0] = v1 - radius * n1 + face_point_out[base1 + 1] = v0 + radius * n0 + face_point_out[base1 + 2] = v0 - radius * n0 -def build_flex_bvh( - mjm: mujoco.MjModel, mjd: mujoco.MjData, nworld: int, constructor: str = "sah", leaf_size: int = 2 -) -> tuple[wp.Mesh, wp.array, wp.array, wp.array, wp.array, wp.array, int]: - """Create a Warp mesh BVH from flex data.""" - if (mjm.flex_dim == 1).any(): - raise ValueError("1D Flex objects are not currently supported.") - nflex = mjm.nflex +@wp.kernel +def _update_flex_3d_face_points( + # Model: + flex_vertadr: wp.array(dtype=int), + flex_shelldataadr: wp.array(dtype=int), + flex_shell: wp.array(dtype=int), + # Data in: + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), + # In: + flex_id: int, + nface: int, + # Out: + face_point_out: wp.array(dtype=wp.vec3), +): + worldid, shellid = wp.tid() + + shell_adr = flex_shelldataadr[flex_id] + vert_adr = flex_vertadr[flex_id] + + face_id = worldid * nface + shellid + fbase = face_id * 3 + + sbase = shell_adr + shellid * 3 + i0 = vert_adr + flex_shell[sbase + 0] + i1 = vert_adr + flex_shell[sbase + 1] + i2 = vert_adr + flex_shell[sbase + 2] + + face_point_out[fbase + 0] = flexvert_xpos_in[worldid, i0] + face_point_out[fbase + 1] = flexvert_xpos_in[worldid, i1] + face_point_out[fbase + 2] = flexvert_xpos_in[worldid, i2] + + +def build_flex_bvh( + mjm: mujoco.MjModel, + mjd: mujoco.MjData, + nworld: int, + flex_id: int, + constructor: str = "sah", + leaf_size: int = 2, +) -> tuple[wp.Mesh, wp.array, wp.array, wp.array, int]: + """Create a Warp mesh BVH for a single 2D or 3D flex.""" nflexvert = mjm.nflexvert - nflexelemdata = len(mjm.flex_elem) + flex_dim = wp.array(mjm.flex_dim, dtype=int) + flex_elemadr = wp.array(mjm.flex_elemadr, dtype=int) + flex_elemnum = wp.array(mjm.flex_elemnum, dtype=int) flex_elem = wp.array(mjm.flex_elem, dtype=int) + flex_elemdataadr = wp.array(mjm.flex_elemdataadr, dtype=int) + flex_vertadr = wp.array(mjm.flex_vertadr, dtype=int) flexvert_xpos = wp.array(np.tile(mjd.flexvert_xpos[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3) - flex_faceadr = [0] - for f in range(nflex): - if mjm.flex_dim[f] == 2: - flex_faceadr.append(flex_faceadr[-1] + 2 * mjm.flex_elemnum[f] + 2 * mjm.flex_shellnum[f]) - elif mjm.flex_dim[f] == 3: - flex_faceadr.append(flex_faceadr[-1] + mjm.flex_shellnum[f]) + dim = int(mjm.flex_dim[flex_id]) + nelem = int(mjm.flex_elemnum[flex_id]) + nshell = int(mjm.flex_shellnum[flex_id]) - nface = int(flex_faceadr[-1]) - flex_faceadr = flex_faceadr[:-1] + if dim == 2: + nface = 2 * nelem + 2 * nshell + else: + nface = nshell face_point = wp.empty(nface * 3 * nworld, dtype=wp.vec3) face_index = wp.empty(nface * 3 * nworld, dtype=wp.int32) @@ -883,8 +1001,8 @@ def build_flex_bvh( wp.launch( kernel=accumulate_flex_vertex_normals, - dim=(nworld, nflexelemdata // 3), - inputs=[flex_elem, flexvert_xpos], + dim=(nworld, mjm.nflexelem), + inputs=[mjm.nflex, flex_dim, flex_vertadr, flex_elemadr, flex_elemnum, flex_elemdataadr, flex_elem, flexvert_xpos], outputs=[flexvert_norm], ) @@ -894,60 +1012,56 @@ def build_flex_bvh( inputs=[flexvert_norm], ) - for f in range(nflex): - dim = mjm.flex_dim[f] - elem_adr = mjm.flex_elemdataadr[f] - nelem = mjm.flex_elemnum[f] - shell_adr = mjm.flex_shelldataadr[f] - nshell = mjm.flex_shellnum[f] - vert_adr = mjm.flex_vertadr[f] - - if dim == 2: - wp.launch( - kernel=_build_flex_2d_elements, - dim=(nworld, nelem), - inputs=[ - flex_elem, - flexvert_xpos, - flexvert_norm, - elem_adr, - vert_adr, - flex_faceadr[f], - mjm.flex_radius[f], - nface, - ], - outputs=[face_point, face_index, group], - ) + elem_adr = mjm.flex_elemdataadr[flex_id] + shell_adr = mjm.flex_shelldataadr[flex_id] + vert_adr = mjm.flex_vertadr[flex_id] - wp.launch( - kernel=_build_flex_2d_sides, - dim=(nworld, nshell), - inputs=[ - flexvert_xpos, - flexvert_norm, - flex_shell, - shell_adr, - vert_adr, - flex_faceadr[f] + 2 * nelem, - mjm.flex_radius[f], - nface, - ], - outputs=[face_point, face_index, group], - ) - elif dim == 3: - wp.launch( - kernel=_build_flex_3d_shells, - dim=(nworld, nshell), - inputs=[ - flexvert_xpos, - flex_shell, - shell_adr, - vert_adr, - flex_faceadr[f], - nface, - ], - outputs=[face_point, face_index, group], - ) + if dim == 2: + wp.launch( + kernel=_build_flex_2d_elements, + dim=(nworld, nelem), + inputs=[ + flex_elem, + flexvert_xpos, + flexvert_norm, + elem_adr, + vert_adr, + 0, # face_offset + mjm.flex_radius[flex_id], + nface, + ], + outputs=[face_point, face_index, group], + ) + + wp.launch( + kernel=_build_flex_2d_sides, + dim=(nworld, nshell), + inputs=[ + flex_shell, + flexvert_xpos, + flexvert_norm, + shell_adr, + vert_adr, + 2 * nelem, # face_offset + mjm.flex_radius[flex_id], + nface, + ], + outputs=[face_point, face_index, group], + ) + elif dim == 3: + wp.launch( + kernel=_build_flex_3d_shells, + dim=(nworld, nshell), + inputs=[ + flex_shell, + flexvert_xpos, + shell_adr, + vert_adr, + 0, # face_offset + nface, + ], + outputs=[face_point, face_index, group], + ) flex_mesh = wp.Mesh( points=face_point, @@ -965,24 +1079,23 @@ def build_flex_bvh( outputs=[group_root], ) - return ( - flex_mesh, - face_point, - group_root, - flex_shell, - flex_faceadr, - nface, - ) + return flex_mesh, group_root def refit_flex_bvh(m: Model, d: Data, rc: RenderContext): - """Refit the flex BVH.""" + """Refit per-flex BVHs.""" flexvert_norm = wp.zeros(d.flexvert_xpos.shape, dtype=wp.vec3) wp.launch( kernel=accumulate_flex_vertex_normals, - dim=(d.nworld, m.nflexelemdata // 3), + dim=(d.nworld, m.nflexelem), inputs=[ + m.nflex, + m.flex_dim, + m.flex_vertadr, + m.flex_elemadr, + m.flex_elemnum, + m.flex_elemdataadr, m.flex_elem, d.flexvert_xpos, ], @@ -991,32 +1104,49 @@ def refit_flex_bvh(m: Model, d: Data, rc: RenderContext): wp.launch( kernel=normalize_vertex_normals, - dim=(d.nworld, m.nflexvert), + dim=(d.nworld, d.flexvert_xpos.shape[1]), inputs=[flexvert_norm], ) - wp.launch( - kernel=_update_flex_face_points, - dim=(d.nworld, rc.flex_nwork), - inputs=[ - m.nflex, - m.flex_dim, - m.flex_vertadr, - m.flex_elemnum, - m.flex_elem, - d.flexvert_xpos, - rc.flex_shell, - flexvert_norm, - rc.flex_elemdataadr, - rc.flex_shelldataadr, - rc.flex_faceadr, - rc.flex_radius, - rc.flex_workadr, - rc.flex_worknum, - rc.flex_nface, - rc.flex_render_smooth, - ], - outputs=[rc.flex_face_point], - ) + for i in range(m.nflex): + if rc.flex_dim_np[i] == 1: + continue + mesh = rc.flex_mesh_registry[i] + nface = mesh.points.shape[0] // (3 * d.nworld) + + if rc.flex_dim_np[i] == 2: + wp.launch( + kernel=_update_flex_2d_face_points, + dim=(d.nworld, nface // 2), + inputs=[ + m.flex_vertadr, + m.flex_elemnum, + m.flex_elemdataadr, + m.flex_shelldataadr, + m.flex_elem, + m.flex_shell, + m.flex_radius, + d.flexvert_xpos, + flexvert_norm, + i, + nface, + rc.flex_render_smooth, + ], + outputs=[mesh.points], + ) + else: + wp.launch( + kernel=_update_flex_3d_face_points, + dim=(d.nworld, nface), + inputs=[ + m.flex_vertadr, + m.flex_shelldataadr, + m.flex_shell, + d.flexvert_xpos, + i, + nface, + ], + outputs=[mesh.points], + ) - rc.flex_mesh.refit() + mesh.refit() diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_convex.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_convex.py index 69ea9cbe3d..12d9824a44 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_convex.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_convex.py @@ -15,34 +15,35 @@ from typing import Tuple +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src.collision_core import CollisionContext -from mujoco.mjx.third_party.mujoco_warp._src.collision_core import contact_params from mujoco.mjx.third_party.mujoco_warp._src.collision_core import Geom +from mujoco.mjx.third_party.mujoco_warp._src.collision_core import contact_params from mujoco.mjx.third_party.mujoco_warp._src.collision_core import geom_collision_pair from mujoco.mjx.third_party.mujoco_warp._src.collision_core import write_contact from mujoco.mjx.third_party.mujoco_warp._src.collision_gjk import ccd from mujoco.mjx.third_party.mujoco_warp._src.collision_gjk import multicontact from mujoco.mjx.third_party.mujoco_warp._src.collision_gjk import support -from mujoco.mjx.third_party.mujoco_warp._src.collision_primitive import contact_params from mujoco.mjx.third_party.mujoco_warp._src.collision_primitive import Geom +from mujoco.mjx.third_party.mujoco_warp._src.collision_primitive import contact_params from mujoco.mjx.third_party.mujoco_warp._src.collision_primitive import geom_collision_pair from mujoco.mjx.third_party.mujoco_warp._src.collision_primitive import write_contact from mujoco.mjx.third_party.mujoco_warp._src.math import make_frame from mujoco.mjx.third_party.mujoco_warp._src.math import upper_trid_index -from mujoco.mjx.third_party.mujoco_warp._src.types import Data -from mujoco.mjx.third_party.mujoco_warp._src.types import EnableBit -from mujoco.mjx.third_party.mujoco_warp._src.types import GeomType -from mujoco.mjx.third_party.mujoco_warp._src.types import mat43 -from mujoco.mjx.third_party.mujoco_warp._src.types import mat63 from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAX_EPAFACES from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAX_EPAHORIZON from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXCONPAIR from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL +from mujoco.mjx.third_party.mujoco_warp._src.types import Data +from mujoco.mjx.third_party.mujoco_warp._src.types import EnableBit +from mujoco.mjx.third_party.mujoco_warp._src.types import GeomType from mujoco.mjx.third_party.mujoco_warp._src.types import Model +from mujoco.mjx.third_party.mujoco_warp._src.types import mat43 +from mujoco.mjx.third_party.mujoco_warp._src.types import mat63 from mujoco.mjx.third_party.mujoco_warp._src.types import vec5 from mujoco.mjx.third_party.mujoco_warp._src.warp_util import cache_kernel from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope -import warp as wp # TODO(team): improve compile time to enable backward pass wp.set_module_options({"enable_backward": False}) @@ -233,6 +234,7 @@ def ccd_hfield_kernel( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -516,6 +518,7 @@ def ccd_hfield_kernel( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -572,6 +575,7 @@ def ccd_hfield_kernel( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -626,6 +630,7 @@ def ccd_hfield_kernel( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -681,6 +686,7 @@ def ccd_hfield_kernel( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -751,6 +757,7 @@ def eval_ccd_write_contact( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -871,6 +878,7 @@ def eval_ccd_write_contact( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -956,6 +964,7 @@ def ccd_kernel( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -1070,6 +1079,7 @@ def ccd_kernel( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -1156,6 +1166,7 @@ def _pair_count(p1: int, p2: int) -> Tuple[int, int]: d.contact.solimp, d.contact.dim, d.contact.geom, + d.contact.efc_address, d.contact.worldid, d.contact.type, d.contact.geomcollisionid, diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_core.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_core.py index f8294f65da..b7affa7ddc 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_core.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_core.py @@ -18,14 +18,15 @@ import dataclasses from typing import Tuple +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src.math import safe_div +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINMU +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL from mujoco.mjx.third_party.mujoco_warp._src.types import ContactType from mujoco.mjx.third_party.mujoco_warp._src.types import GeomType from mujoco.mjx.third_party.mujoco_warp._src.types import mat63 -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINMU -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL from mujoco.mjx.third_party.mujoco_warp._src.types import vec5 -import warp as wp wp.set_module_options({"enable_backward": False}) @@ -63,30 +64,30 @@ class Geom: @wp.func def geom_collision_pair( - # Model: - geom_type: wp.array(dtype=int), - geom_dataid: wp.array(dtype=int), - geom_size: wp.array2d(dtype=wp.vec3), - mesh_vertadr: wp.array(dtype=int), - mesh_vertnum: wp.array(dtype=int), - mesh_graphadr: wp.array(dtype=int), - mesh_vert: wp.array(dtype=wp.vec3), - mesh_graph: wp.array(dtype=int), - mesh_polynum: wp.array(dtype=int), - mesh_polyadr: wp.array(dtype=int), - mesh_polynormal: wp.array(dtype=wp.vec3), - mesh_polyvertadr: wp.array(dtype=int), - mesh_polyvertnum: wp.array(dtype=int), - mesh_polyvert: wp.array(dtype=int), - mesh_polymapadr: wp.array(dtype=int), - mesh_polymapnum: wp.array(dtype=int), - mesh_polymap: wp.array(dtype=int), - # Data in: - geom_xpos_in: wp.array2d(dtype=wp.vec3), - geom_xmat_in: wp.array2d(dtype=wp.mat33), - # In: - geoms: wp.vec2i, - worldid: int, + # Model: + geom_type: wp.array(dtype=int), + geom_dataid: wp.array(dtype=int), + geom_size: wp.array2d(dtype=wp.vec3), + mesh_vertadr: wp.array(dtype=int), + mesh_vertnum: wp.array(dtype=int), + mesh_graphadr: wp.array(dtype=int), + mesh_vert: wp.array(dtype=wp.vec3), + mesh_graph: wp.array(dtype=int), + mesh_polynum: wp.array(dtype=int), + mesh_polyadr: wp.array(dtype=int), + mesh_polynormal: wp.array(dtype=wp.vec3), + mesh_polyvertadr: wp.array(dtype=int), + mesh_polyvertnum: wp.array(dtype=int), + mesh_polyvert: wp.array(dtype=int), + mesh_polymapadr: wp.array(dtype=int), + mesh_polymapnum: wp.array(dtype=int), + mesh_polymap: wp.array(dtype=int), + # Data in: + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + # In: + geoms: wp.vec2i, + worldid: int, ) -> Tuple[Geom, Geom]: geom1 = Geom() geom2 = Geom() @@ -155,38 +156,39 @@ def geom_collision_pair( @wp.func def write_contact( - # Data in: - naconmax_in: int, - # In: - id_: int, - dist_in: float, - pos_in: wp.vec3, - frame_in: wp.mat33, - margin_in: float, - gap_in: float, - condim_in: int, - friction_in: vec5, - solref_in: wp.vec2, - solreffriction_in: wp.vec2, - solimp_in: vec5, - geoms_in: wp.vec2i, - pairid_in: wp.vec2i, - worldid_in: int, - # Data out: - contact_dist_out: wp.array(dtype=float), - contact_pos_out: wp.array(dtype=wp.vec3), - contact_frame_out: wp.array(dtype=wp.mat33), - contact_includemargin_out: wp.array(dtype=float), - contact_friction_out: wp.array(dtype=vec5), - contact_solref_out: wp.array(dtype=wp.vec2), - contact_solreffriction_out: wp.array(dtype=wp.vec2), - contact_solimp_out: wp.array(dtype=vec5), - contact_dim_out: wp.array(dtype=int), - contact_geom_out: wp.array(dtype=wp.vec2i), - contact_worldid_out: wp.array(dtype=int), - contact_type_out: wp.array(dtype=int), - contact_geomcollisionid_out: wp.array(dtype=int), - nacon_out: wp.array(dtype=int), + # Data in: + naconmax_in: int, + # In: + id_: int, + dist_in: float, + pos_in: wp.vec3, + frame_in: wp.mat33, + margin_in: float, + gap_in: float, + condim_in: int, + friction_in: vec5, + solref_in: wp.vec2, + solreffriction_in: wp.vec2, + solimp_in: vec5, + geoms_in: wp.vec2i, + pairid_in: wp.vec2i, + worldid_in: int, + # Data out: + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_dim_out: wp.array(dtype=int), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), + contact_worldid_out: wp.array(dtype=int), + contact_type_out: wp.array(dtype=int), + contact_geomcollisionid_out: wp.array(dtype=int), + nacon_out: wp.array(dtype=int), ) -> int: """Atomically write a detected contact into the contact output arrays. @@ -222,33 +224,35 @@ def write_contact( contact_solimp_out[cid] = solimp_in contact_type_out[cid] = contact_type contact_geomcollisionid_out[cid] = id_ + for i in range(contact_efc_address_out.shape[1]): + contact_efc_address_out[cid, i] = -1 return int(active) return 0 @wp.func def contact_params( - # Model: - geom_condim: wp.array(dtype=int), - geom_priority: wp.array(dtype=int), - geom_solmix: wp.array2d(dtype=float), - geom_solref: wp.array2d(dtype=wp.vec2), - geom_solimp: wp.array2d(dtype=vec5), - geom_friction: wp.array2d(dtype=wp.vec3), - geom_margin: wp.array2d(dtype=float), - geom_gap: wp.array2d(dtype=float), - pair_dim: wp.array(dtype=int), - pair_solref: wp.array2d(dtype=wp.vec2), - pair_solreffriction: wp.array2d(dtype=wp.vec2), - pair_solimp: wp.array2d(dtype=vec5), - pair_margin: wp.array2d(dtype=float), - pair_gap: wp.array2d(dtype=float), - pair_friction: wp.array2d(dtype=vec5), - # In: - collision_pair_in: wp.array(dtype=wp.vec2i), - collision_pairid_in: wp.array(dtype=wp.vec2i), - cid: int, - worldid: int, + # Model: + geom_condim: wp.array(dtype=int), + geom_priority: wp.array(dtype=int), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), + pair_dim: wp.array(dtype=int), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), + # In: + collision_pair_in: wp.array(dtype=wp.vec2i), + collision_pairid_in: wp.array(dtype=wp.vec2i), + cid: int, + worldid: int, ): """Resolve contact parameters for a collision pair. @@ -267,9 +271,7 @@ def contact_params( condim = pair_dim[pairid] friction = pair_friction[worldid % pair_friction.shape[0], pairid] solref = pair_solref[worldid % pair_solref.shape[0], pairid] - solreffriction = pair_solreffriction[ - worldid % pair_solreffriction.shape[0], pairid - ] + solreffriction = pair_solreffriction[worldid % pair_solreffriction.shape[0], pairid] solimp = pair_solimp[worldid % pair_solimp.shape[0], pairid] else: g1 = geoms[0] @@ -305,44 +307,33 @@ def contact_params( mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 >= MJ_MINVAL), 0.0, mix) mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix) condim = wp.max(condim1, condim2) - max_geom_friction = wp.max( - geom_friction[friction_id, g1], geom_friction[friction_id, g2] - ) + max_geom_friction = wp.max(geom_friction[friction_id, g1], geom_friction[friction_id, g2]) friction = vec5( - max_geom_friction[0], - max_geom_friction[0], - max_geom_friction[1], - max_geom_friction[2], - max_geom_friction[2], + max_geom_friction[0], + max_geom_friction[0], + max_geom_friction[1], + max_geom_friction[2], + max_geom_friction[2], ) - if ( - geom_solref[solref_id, g1][0] > 0.0 - and geom_solref[solref_id, g2][0] > 0.0 - ): - solref = ( - mix * geom_solref[solref_id, g1] - + (1.0 - mix) * geom_solref[solref_id, g2] - ) + if geom_solref[solref_id, g1][0] > 0.0 and geom_solref[solref_id, g2][0] > 0.0: + solref = mix * geom_solref[solref_id, g1] + (1.0 - mix) * geom_solref[solref_id, g2] else: solref = wp.min(geom_solref[solref_id, g1], geom_solref[solref_id, g2]) solreffriction = wp.vec2(0.0, 0.0) - solimp = ( - mix * geom_solimp[solimp_id, g1] - + (1.0 - mix) * geom_solimp[solimp_id, g2] - ) + solimp = mix * geom_solimp[solimp_id, g1] + (1.0 - mix) * geom_solimp[solimp_id, g2] # geom priority is ignored margin = geom_margin[margin_id, g1] + geom_margin[margin_id, g2] gap = geom_gap[gap_id, g1] + geom_gap[gap_id, g2] friction = vec5( - wp.max(MJ_MINMU, friction[0]), - wp.max(MJ_MINMU, friction[1]), - wp.max(MJ_MINMU, friction[2]), - wp.max(MJ_MINMU, friction[3]), - wp.max(MJ_MINMU, friction[4]), + wp.max(MJ_MINMU, friction[0]), + wp.max(MJ_MINMU, friction[1]), + wp.max(MJ_MINMU, friction[2]), + wp.max(MJ_MINMU, friction[3]), + wp.max(MJ_MINMU, friction[4]), ) return geoms, margin, gap, condim, friction, solref, solreffriction, solimp @@ -366,7 +357,7 @@ class CollisionContext: def create_collision_context(naconmax: int) -> CollisionContext: """Create a CollisionContext with allocated arrays.""" return CollisionContext( - collision_pair=wp.empty(naconmax, dtype=wp.vec2i), - collision_pairid=wp.empty(naconmax, dtype=wp.vec2i), - collision_worldid=wp.empty(naconmax, dtype=int), + collision_pair=wp.empty(naconmax, dtype=wp.vec2i), + collision_pairid=wp.empty(naconmax, dtype=wp.vec2i), + collision_worldid=wp.empty(naconmax, dtype=int), ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_driver.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_driver.py index 36cdaea16d..786ee9b888 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_driver.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_driver.py @@ -15,25 +15,27 @@ from typing import Any +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src.collision_convex import convex_narrowphase from mujoco.mjx.third_party.mujoco_warp._src.collision_core import CollisionContext from mujoco.mjx.third_party.mujoco_warp._src.collision_core import create_collision_context +from mujoco.mjx.third_party.mujoco_warp._src.collision_flex import flex_narrowphase from mujoco.mjx.third_party.mujoco_warp._src.collision_primitive import primitive_narrowphase from mujoco.mjx.third_party.mujoco_warp._src.collision_sdf import sdf_narrowphase from mujoco.mjx.third_party.mujoco_warp._src.math import upper_tri_index +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL from mujoco.mjx.third_party.mujoco_warp._src.types import BroadphaseFilter from mujoco.mjx.third_party.mujoco_warp._src.types import BroadphaseType from mujoco.mjx.third_party.mujoco_warp._src.types import CollisionType from mujoco.mjx.third_party.mujoco_warp._src.types import Data from mujoco.mjx.third_party.mujoco_warp._src.types import DisableBit from mujoco.mjx.third_party.mujoco_warp._src.types import GeomType +from mujoco.mjx.third_party.mujoco_warp._src.types import Model from mujoco.mjx.third_party.mujoco_warp._src.types import mat23 from mujoco.mjx.third_party.mujoco_warp._src.types import mat63 -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL -from mujoco.mjx.third_party.mujoco_warp._src.types import Model from mujoco.mjx.third_party.mujoco_warp._src.warp_util import cache_kernel from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope -import warp as wp wp.set_module_options({"enable_backward": False}) @@ -290,25 +292,13 @@ def func( # 8: obb aabb_id = worldid % ngeom_aabb if wp.static(ngeom_aabb > 1) else 0 - center1, center2 = ( - geom_aabb[aabb_id, geom1, 0], - geom_aabb[aabb_id, geom2, 0], - ) # kernel_analyzer: ignore - size1, size2 = ( - geom_aabb[aabb_id, geom1, 1], - geom_aabb[aabb_id, geom2, 1], - ) # kernel_analyzer: ignore + center1, center2 = geom_aabb[aabb_id, geom1, 0], geom_aabb[aabb_id, geom2, 0] # kernel_analyzer: ignore + size1, size2 = geom_aabb[aabb_id, geom1, 1], geom_aabb[aabb_id, geom2, 1] # kernel_analyzer: ignore rbound_id = worldid % ngeom_rbound if wp.static(ngeom_rbound > 1) else 0 - rbound1, rbound2 = ( - geom_rbound[rbound_id, geom1], - geom_rbound[rbound_id, geom2], - ) # kernel_analyzer: ignore + rbound1, rbound2 = geom_rbound[rbound_id, geom1], geom_rbound[rbound_id, geom2] # kernel_analyzer: ignore margin_id = worldid % ngeom_margin if wp.static(ngeom_margin > 1) else 0 - margin1, margin2 = ( - geom_margin[margin_id, geom1], - geom_margin[margin_id, geom2], - ) # kernel_analyzer: ignore + margin1, margin2 = geom_margin[margin_id, geom1], geom_margin[margin_id, geom2] # kernel_analyzer: ignore xpos1, xpos2 = geom_xpos_in[worldid, geom1], geom_xpos_in[worldid, geom2] xmat1, xmat2 = geom_xmat_in[worldid, geom1], geom_xmat_in[worldid, geom2] @@ -757,6 +747,9 @@ def _narrowphase(m: Model, d: Data, ctx: CollisionContext): if m.has_sdf_geom: sdf_narrowphase(m, d, ctx) + if m.nflex > 0: + flex_narrowphase(m, d) + @event_scope def collision(m: Model, d: Data): diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_flex.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_flex.py new file mode 100644 index 0000000000..215423e991 --- /dev/null +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_flex.py @@ -0,0 +1,834 @@ +# Copyright 2026 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flex collision detection (geom vs flex triangles).""" + +import warp as wp + +from mujoco.mjx.third_party.mujoco_warp._src import collision_primitive_core +from mujoco.mjx.third_party.mujoco_warp._src.math import make_frame +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINMU +from mujoco.mjx.third_party.mujoco_warp._src.types import Data +from mujoco.mjx.third_party.mujoco_warp._src.types import GeomType +from mujoco.mjx.third_party.mujoco_warp._src.types import Model +from mujoco.mjx.third_party.mujoco_warp._src.types import vec5 +from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope + +wp.set_module_options({"enable_backward": False}) + + +@wp.func +def _write_flex_contact( + # Data in: + naconmax_in: int, + # In: + dist: float, + pos: wp.vec3, + frame: wp.mat33, + margin: float, + condim: int, + friction: vec5, + solref: wp.vec2, + solimp: vec5, + geom: int, + flexid: int, + vertid: int, + worldid: int, + # Data out: + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_dim_out: wp.array(dtype=int), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_flex_out: wp.array(dtype=wp.vec2i), + contact_vert_out: wp.array(dtype=wp.vec2i), + contact_worldid_out: wp.array(dtype=int), + contact_type_out: wp.array(dtype=int), + contact_geomcollisionid_out: wp.array(dtype=int), + nacon_out: wp.array(dtype=int), +): + if dist >= margin or dist >= MJ_MAXVAL: + return + + id_ = wp.atomic_add(nacon_out, 0, 1) + if id_ >= naconmax_in: + return + + contact_dist_out[id_] = dist + contact_pos_out[id_] = pos + contact_frame_out[id_] = frame + contact_includemargin_out[id_] = margin + contact_friction_out[id_] = friction + contact_solref_out[id_] = solref + contact_solreffriction_out[id_] = wp.vec2(0.0, 0.0) + contact_solimp_out[id_] = solimp + contact_dim_out[id_] = condim + contact_geom_out[id_] = wp.vec2i(geom, -1) + contact_flex_out[id_] = wp.vec2i(-1, flexid) + contact_vert_out[id_] = wp.vec2i(-1, vertid) + contact_worldid_out[id_] = worldid + contact_type_out[id_] = 1 + contact_geomcollisionid_out[id_] = 0 + + +@wp.func +def _collide_geom_triangle( + # Data in: + naconmax_in: int, + # In: + gtype: int, + pos: wp.vec3, + rot: wp.mat33, + size_val: wp.vec3, + t1: wp.vec3, + t2: wp.vec3, + t3: wp.vec3, + tri_radius: float, + margin: float, + condim: int, + friction: vec5, + solref: wp.vec2, + solimp: vec5, + geomid: int, + flexid: int, + vertex_id: int, + worldid: int, + # Data out: + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_dim_out: wp.array(dtype=int), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_flex_out: wp.array(dtype=wp.vec2i), + contact_vert_out: wp.array(dtype=wp.vec2i), + contact_worldid_out: wp.array(dtype=int), + contact_type_out: wp.array(dtype=int), + contact_geomcollisionid_out: wp.array(dtype=int), + nacon_out: wp.array(dtype=int), +): + if gtype == int(GeomType.SPHERE): + sphere_radius = size_val[0] + dist, contact_pos, nrm = collision_primitive_core.sphere_triangle(pos, sphere_radius, t1, t2, t3, tri_radius) + if dist < margin: + _write_flex_contact( + naconmax_in, + dist, + contact_pos, + make_frame(nrm), + margin, + condim, + friction, + solref, + solimp, + geomid, + flexid, + vertex_id, + worldid, + contact_dist_out, + contact_pos_out, + contact_frame_out, + contact_includemargin_out, + contact_friction_out, + contact_solref_out, + contact_solreffriction_out, + contact_solimp_out, + contact_dim_out, + contact_geom_out, + contact_flex_out, + contact_vert_out, + contact_worldid_out, + contact_type_out, + contact_geomcollisionid_out, + nacon_out, + ) + return + + # Capsule, box, cylinder all return up to 2 contacts - compute then share writing code + dists = wp.vec2(collision_primitive_core.MJ_MAXVAL, collision_primitive_core.MJ_MAXVAL) + poss = collision_primitive_core.mat23f(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + nrms = collision_primitive_core.mat23f(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + + if gtype == int(GeomType.CAPSULE): + cap_radius = size_val[0] + cap_half_len = size_val[1] + cap_axis = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) + dists, poss, nrms = collision_primitive_core.capsule_triangle( + pos, cap_axis, cap_radius, cap_half_len, t1, t2, t3, tri_radius + ) + elif gtype == int(GeomType.BOX): + dists, poss, nrms = collision_primitive_core.box_triangle(pos, rot, size_val, t1, t2, t3, tri_radius) + elif gtype == int(GeomType.CYLINDER): + cyl_radius = size_val[0] + cyl_half_height = size_val[1] + cyl_axis = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) + dists, poss, nrms = collision_primitive_core.cylinder_triangle( + pos, cyl_axis, cyl_radius, cyl_half_height, t1, t2, t3, tri_radius + ) + + # Write up to 2 contacts (shared code for capsule/box/cylinder) + if dists[0] < margin: + p1 = wp.vec3(poss[0, 0], poss[0, 1], poss[0, 2]) + n1 = wp.vec3(nrms[0, 0], nrms[0, 1], nrms[0, 2]) + _write_flex_contact( + naconmax_in, + dists[0], + p1, + make_frame(n1), + margin, + condim, + friction, + solref, + solimp, + geomid, + flexid, + vertex_id, + worldid, + contact_dist_out, + contact_pos_out, + contact_frame_out, + contact_includemargin_out, + contact_friction_out, + contact_solref_out, + contact_solreffriction_out, + contact_solimp_out, + contact_dim_out, + contact_geom_out, + contact_flex_out, + contact_vert_out, + contact_worldid_out, + contact_type_out, + contact_geomcollisionid_out, + nacon_out, + ) + if dists[1] < margin: + p2 = wp.vec3(poss[1, 0], poss[1, 1], poss[1, 2]) + n2 = wp.vec3(nrms[1, 0], nrms[1, 1], nrms[1, 2]) + _write_flex_contact( + naconmax_in, + dists[1], + p2, + make_frame(n2), + margin, + condim, + friction, + solref, + solimp, + geomid, + flexid, + vertex_id, + worldid, + contact_dist_out, + contact_pos_out, + contact_frame_out, + contact_includemargin_out, + contact_friction_out, + contact_solref_out, + contact_solreffriction_out, + contact_solimp_out, + contact_dim_out, + contact_geom_out, + contact_flex_out, + contact_vert_out, + contact_worldid_out, + contact_type_out, + contact_geomcollisionid_out, + nacon_out, + ) + + +@wp.kernel +def _flex_plane_narrowphase( + # Model: + ngeom: int, + nflexvert: int, + geom_type: wp.array(dtype=int), + geom_condim: wp.array(dtype=int), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + flex_condim: wp.array(dtype=int), + flex_friction: wp.array(dtype=wp.vec3), + flex_margin: wp.array(dtype=float), + flex_vertadr: wp.array(dtype=int), + flex_radius: wp.array(dtype=float), + flex_vertflexid: wp.array(dtype=int), + # Data in: + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), + nworld_in: int, + naconmax_in: int, + # Data out: + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_dim_out: wp.array(dtype=int), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_flex_out: wp.array(dtype=wp.vec2i), + contact_vert_out: wp.array(dtype=wp.vec2i), + contact_worldid_out: wp.array(dtype=int), + contact_type_out: wp.array(dtype=int), + contact_geomcollisionid_out: wp.array(dtype=int), + nacon_out: wp.array(dtype=int), +): + worldid, vertid = wp.tid() + + flexid = flex_vertflexid[vertid] + radius = flex_radius[flexid] + flex_margin_val = flex_margin[flexid] + flex_condim_val = flex_condim[flexid] + flex_fric = flex_friction[flexid] + # Convert global vertid to local vertex index within this flex + local_vertid = vertid - flex_vertadr[flexid] + + vert = flexvert_xpos_in[worldid, vertid] + + # TODO: Add a broadphase + for geomid in range(ngeom): + gtype = geom_type[geomid] + if gtype != int(GeomType.PLANE): + continue + + plane_pos = geom_xpos_in[worldid, geomid] + plane_rot = geom_xmat_in[worldid, geomid] + plane_normal = wp.vec3(plane_rot[0, 2], plane_rot[1, 2], plane_rot[2, 2]) + + margin = geom_margin[worldid % geom_margin.shape[0], geomid] + flex_margin_val + + diff = vert - plane_pos + signed_dist = wp.dot(diff, plane_normal) + dist = signed_dist - radius + + if dist < margin: + geom_condim_val = geom_condim[geomid] + condim = wp.max(geom_condim_val, flex_condim_val) + solref = geom_solref[worldid % geom_solref.shape[0], geomid] + solimp = geom_solimp[worldid % geom_solimp.shape[0], geomid] + geom_fric = geom_friction[worldid % geom_friction.shape[0], geomid] + fric0 = wp.max(geom_fric[0], flex_fric[0]) + fric1 = wp.max(geom_fric[1], flex_fric[1]) + fric2 = wp.max(geom_fric[2], flex_fric[2]) + friction = vec5( + wp.max(MJ_MINMU, fric0), + wp.max(MJ_MINMU, fric0), + wp.max(MJ_MINMU, fric1), + wp.max(MJ_MINMU, fric2), + wp.max(MJ_MINMU, fric2), + ) + + contact_pos = vert - plane_normal * (dist * 0.5 + radius) + _write_flex_contact( + naconmax_in, + dist, + contact_pos, + make_frame(plane_normal), + margin, + condim, + friction, + solref, + solimp, + geomid, + flexid, + local_vertid, + worldid, + contact_dist_out, + contact_pos_out, + contact_frame_out, + contact_includemargin_out, + contact_friction_out, + contact_solref_out, + contact_solreffriction_out, + contact_solimp_out, + contact_dim_out, + contact_geom_out, + contact_flex_out, + contact_vert_out, + contact_worldid_out, + contact_type_out, + contact_geomcollisionid_out, + nacon_out, + ) + + +@wp.kernel +def _flex_narrowphase_dim2( + # Model: + ngeom: int, + nflex: int, + geom_type: wp.array(dtype=int), + geom_contype: wp.array(dtype=int), + geom_conaffinity: wp.array(dtype=int), + geom_condim: wp.array(dtype=int), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_size: wp.array2d(dtype=wp.vec3), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + flex_contype: wp.array(dtype=int), + flex_conaffinity: wp.array(dtype=int), + flex_margin: wp.array(dtype=float), + flex_dim: wp.array(dtype=int), + flex_vertadr: wp.array(dtype=int), + flex_elemadr: wp.array(dtype=int), + flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), + flex_elem: wp.array(dtype=int), + flex_radius: wp.array(dtype=float), + # Data in: + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), + nworld_in: int, + naconmax_in: int, + # Data out: + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_dim_out: wp.array(dtype=int), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_flex_out: wp.array(dtype=wp.vec2i), + contact_vert_out: wp.array(dtype=wp.vec2i), + contact_worldid_out: wp.array(dtype=int), + contact_type_out: wp.array(dtype=int), + contact_geomcollisionid_out: wp.array(dtype=int), + nacon_out: wp.array(dtype=int), +): + worldid, elemid = wp.tid() + + flexid = int(-1) + for i in range(nflex): + if flex_dim[i] != 2: + continue + elem_adr = flex_elemadr[i] + elem_num = flex_elemnum[i] + if elemid >= elem_adr and elemid < elem_adr + elem_num: + flexid = i + break + + if flexid < 0: + return + + vert_adr = flex_vertadr[flexid] + tri_radius = flex_radius[flexid] + tri_margin = flex_margin[flexid] + + elem_data_idx = flex_elemdataadr[flexid] + (elemid - flex_elemadr[flexid]) * 3 + v0_local = flex_elem[elem_data_idx] + v1_local = flex_elem[elem_data_idx + 1] + v2_local = flex_elem[elem_data_idx + 2] + + t1 = flexvert_xpos_in[worldid, vert_adr + v0_local] + t2 = flexvert_xpos_in[worldid, vert_adr + v1_local] + t3 = flexvert_xpos_in[worldid, vert_adr + v2_local] + + # TODO: Add a broadphase + for geomid in range(ngeom): + gtype = geom_type[geomid] + if ( + gtype != int(GeomType.SPHERE) + and gtype != int(GeomType.CAPSULE) + and gtype != int(GeomType.BOX) + and gtype != int(GeomType.CYLINDER) + ): + continue + + g_contype = geom_contype[geomid] + g_conaffinity = geom_conaffinity[geomid] + f_contype = flex_contype[flexid] + f_conaffinity = flex_conaffinity[flexid] + if not ((g_contype & f_conaffinity) or (f_contype & g_conaffinity)): + continue + + geom_margin_val = geom_margin[worldid % geom_margin.shape[0], geomid] + margin = geom_margin_val + tri_margin + + geom_pos = geom_xpos_in[worldid, geomid] + geom_rot = geom_xmat_in[worldid, geomid] + geom_size_val = geom_size[worldid % geom_size.shape[0], geomid] + + condim = geom_condim[geomid] + gf = geom_friction[worldid % geom_friction.shape[0], geomid] + friction = vec5( + wp.max(MJ_MINMU, gf[0]), + wp.max(MJ_MINMU, gf[0]), + wp.max(MJ_MINMU, gf[1]), + wp.max(MJ_MINMU, gf[2]), + wp.max(MJ_MINMU, gf[2]), + ) + solref = geom_solref[worldid % geom_solref.shape[0], geomid] + solimp = geom_solimp[worldid % geom_solimp.shape[0], geomid] + + _collide_geom_triangle( + naconmax_in, + gtype, + geom_pos, + geom_rot, + geom_size_val, + t1, + t2, + t3, + tri_radius, + margin, + condim, + friction, + solref, + solimp, + geomid, + flexid, + v0_local, + worldid, + contact_dist_out, + contact_pos_out, + contact_frame_out, + contact_includemargin_out, + contact_friction_out, + contact_solref_out, + contact_solreffriction_out, + contact_solimp_out, + contact_dim_out, + contact_geom_out, + contact_flex_out, + contact_vert_out, + contact_worldid_out, + contact_type_out, + contact_geomcollisionid_out, + nacon_out, + ) + + +@wp.kernel +def _flex_narrowphase_dim3( + # Model: + ngeom: int, + nflex: int, + geom_type: wp.array(dtype=int), + geom_contype: wp.array(dtype=int), + geom_conaffinity: wp.array(dtype=int), + geom_condim: wp.array(dtype=int), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_size: wp.array2d(dtype=wp.vec3), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + flex_contype: wp.array(dtype=int), + flex_conaffinity: wp.array(dtype=int), + flex_margin: wp.array(dtype=float), + flex_dim: wp.array(dtype=int), + flex_vertadr: wp.array(dtype=int), + flex_shellnum: wp.array(dtype=int), + flex_shelldataadr: wp.array(dtype=int), + flex_shell: wp.array(dtype=int), + flex_radius: wp.array(dtype=float), + # Data in: + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), + nworld_in: int, + naconmax_in: int, + # Data out: + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_dim_out: wp.array(dtype=int), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_flex_out: wp.array(dtype=wp.vec2i), + contact_vert_out: wp.array(dtype=wp.vec2i), + contact_worldid_out: wp.array(dtype=int), + contact_type_out: wp.array(dtype=int), + contact_geomcollisionid_out: wp.array(dtype=int), + nacon_out: wp.array(dtype=int), +): + worldid, shellid = wp.tid() + + flexid = int(-1) + shell_offset = int(0) + for i in range(nflex): + if flex_dim[i] != 3: + continue + shell_num = flex_shellnum[i] + if shellid >= shell_offset and shellid < shell_offset + shell_num: + flexid = i + break + shell_offset += shell_num + + if flexid < 0: + return + + vert_adr = flex_vertadr[flexid] + tri_radius = flex_radius[flexid] + tri_margin = flex_margin[flexid] + + shell_adr = flex_shelldataadr[flexid] + local_shellid = shellid - shell_offset + shell_data_idx = shell_adr + local_shellid * 3 + + v0_local = flex_shell[shell_data_idx] + v1_local = flex_shell[shell_data_idx + 1] + v2_local = flex_shell[shell_data_idx + 2] + + t1 = flexvert_xpos_in[worldid, vert_adr + v0_local] + t2 = flexvert_xpos_in[worldid, vert_adr + v1_local] + t3 = flexvert_xpos_in[worldid, vert_adr + v2_local] + + # TODO: Add a broadphase + for geomid in range(ngeom): + gtype = geom_type[geomid] + if ( + gtype != int(GeomType.SPHERE) + and gtype != int(GeomType.CAPSULE) + and gtype != int(GeomType.BOX) + and gtype != int(GeomType.CYLINDER) + ): + continue + + g_contype = geom_contype[geomid] + g_conaffinity = geom_conaffinity[geomid] + f_contype = flex_contype[flexid] + f_conaffinity = flex_conaffinity[flexid] + if not ((g_contype & f_conaffinity) or (f_contype & g_conaffinity)): + continue + + geom_margin_val = geom_margin[worldid % geom_margin.shape[0], geomid] + margin = geom_margin_val + tri_margin + + geom_pos = geom_xpos_in[worldid, geomid] + geom_rot = geom_xmat_in[worldid, geomid] + geom_size_val = geom_size[worldid % geom_size.shape[0], geomid] + + condim = geom_condim[geomid] + gf = geom_friction[worldid % geom_friction.shape[0], geomid] + friction = vec5( + wp.max(MJ_MINMU, gf[0]), + wp.max(MJ_MINMU, gf[0]), + wp.max(MJ_MINMU, gf[1]), + wp.max(MJ_MINMU, gf[2]), + wp.max(MJ_MINMU, gf[2]), + ) + solref = geom_solref[worldid % geom_solref.shape[0], geomid] + solimp = geom_solimp[worldid % geom_solimp.shape[0], geomid] + + _collide_geom_triangle( + naconmax_in, + gtype, + geom_pos, + geom_rot, + geom_size_val, + t1, + t2, + t3, + tri_radius, + margin, + condim, + friction, + solref, + solimp, + geomid, + flexid, + v0_local, + worldid, + contact_dist_out, + contact_pos_out, + contact_frame_out, + contact_includemargin_out, + contact_friction_out, + contact_solref_out, + contact_solreffriction_out, + contact_solimp_out, + contact_dim_out, + contact_geom_out, + contact_flex_out, + contact_vert_out, + contact_worldid_out, + contact_type_out, + contact_geomcollisionid_out, + nacon_out, + ) + + +@event_scope +def flex_narrowphase(m: Model, d: Data): + """Runs collision detection between geoms and flex elements.""" + if m.nflex == 0: + return + + wp.launch( + _flex_narrowphase_dim2, + dim=(d.nworld, m.nflexelem), + inputs=[ + m.ngeom, + m.nflex, + m.geom_type, + m.geom_contype, + m.geom_conaffinity, + m.geom_condim, + m.geom_solref, + m.geom_solimp, + m.geom_size, + m.geom_friction, + m.geom_margin, + m.flex_contype, + m.flex_conaffinity, + m.flex_margin, + m.flex_dim, + m.flex_vertadr, + m.flex_elemadr, + m.flex_elemnum, + m.flex_elemdataadr, + m.flex_elem, + m.flex_radius, + d.geom_xpos, + d.geom_xmat, + d.flexvert_xpos, + d.nworld, + d.naconmax, + ], + outputs=[ + d.contact.dist, + d.contact.pos, + d.contact.frame, + d.contact.includemargin, + d.contact.friction, + d.contact.solref, + d.contact.solreffriction, + d.contact.solimp, + d.contact.dim, + d.contact.geom, + d.contact.flex, + d.contact.vert, + d.contact.worldid, + d.contact.type, + d.contact.geomcollisionid, + d.nacon, + ], + ) + + wp.launch( + _flex_narrowphase_dim3, + dim=(d.nworld, m.nflexshelldata // 3), + inputs=[ + m.ngeom, + m.nflex, + m.geom_type, + m.geom_contype, + m.geom_conaffinity, + m.geom_condim, + m.geom_solref, + m.geom_solimp, + m.geom_size, + m.geom_friction, + m.geom_margin, + m.flex_contype, + m.flex_conaffinity, + m.flex_margin, + m.flex_dim, + m.flex_vertadr, + m.flex_shellnum, + m.flex_shelldataadr, + m.flex_shell, + m.flex_radius, + d.geom_xpos, + d.geom_xmat, + d.flexvert_xpos, + d.nworld, + d.naconmax, + ], + outputs=[ + d.contact.dist, + d.contact.pos, + d.contact.frame, + d.contact.includemargin, + d.contact.friction, + d.contact.solref, + d.contact.solreffriction, + d.contact.solimp, + d.contact.dim, + d.contact.geom, + d.contact.flex, + d.contact.vert, + d.contact.worldid, + d.contact.type, + d.contact.geomcollisionid, + d.nacon, + ], + ) + + wp.launch( + _flex_plane_narrowphase, + dim=(d.nworld, m.nflexvert), + inputs=[ + m.ngeom, + m.nflexvert, + m.geom_type, + m.geom_condim, + m.geom_solref, + m.geom_solimp, + m.geom_friction, + m.geom_margin, + m.flex_condim, + m.flex_friction, + m.flex_margin, + m.flex_vertadr, + m.flex_radius, + m.flex_vertflexid, + d.geom_xpos, + d.geom_xmat, + d.flexvert_xpos, + d.nworld, + d.naconmax, + ], + outputs=[ + d.contact.dist, + d.contact.pos, + d.contact.frame, + d.contact.includemargin, + d.contact.friction, + d.contact.solref, + d.contact.solreffriction, + d.contact.solimp, + d.contact.dim, + d.contact.geom, + d.contact.flex, + d.contact.vert, + d.contact.worldid, + d.contact.type, + d.contact.geomcollisionid, + d.nacon, + ], + ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_gjk.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_gjk.py index 7e4948ed1a..fe1c4445d3 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_gjk.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_gjk.py @@ -16,11 +16,12 @@ import math from typing import Tuple +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src.collision_core import Geom from mujoco.mjx.third_party.mujoco_warp._src.types import GeomType from mujoco.mjx.third_party.mujoco_warp._src.types import mat43 from mujoco.mjx.third_party.mujoco_warp._src.types import mat63 -import warp as wp # TODO(team): improve compile time to enable backward pass wp.set_module_options({"enable_backward": False}) @@ -581,16 +582,19 @@ def gjk( simplex_index2 = wp.vec4i() n = int(0) coordinates = wp.vec4() # barycentric coordinates - epsilon = wp.where(is_discrete, 0.0, 0.5 * tolerance * tolerance) + tol2 = tolerance * tolerance + epsilon = wp.where(is_discrete, 0.0, 0.5 * tol2) # set initial guess x_k = x1_0 - x2_0 + xnorm_old = FLOAT_MAX - for k in range(gjk_iterations): + for _ in range(gjk_iterations): xnorm = wp.dot(x_k, x_k) # TODO(kbayes): determine new constant here - if xnorm < 1e-12: + if xnorm < tol2 or wp.abs(xnorm_old - xnorm) < tol2: break + xnorm_old = xnorm dir_neg = x_k / wp.sqrt(xnorm) # compute kth support point in geom1 @@ -663,13 +667,6 @@ def gjk( if n == 4: break - if k == gjk_iterations - 1: - wp.printf( - "Warning: opt.ccd_iterations, currently set to %d, needs to be" - " increased.\n", - gjk_iterations, - ) - result = GJKResult() # compute the approximate witness points @@ -1205,7 +1202,6 @@ def _is_invalid_face(face: int) -> bool: def _epa( # In: tolerance: float, - gjk_iterations: int, epa_iterations: int, pt: Polytope, geom1: Geom, @@ -1226,7 +1222,7 @@ def _epa( # so iterations must be cap to limit the number of generated vertices # (one new vertex per iteration) epa_iterations = wp.min(epa_iterations, 1000) - for k in range(epa_iterations): + for _ in range(epa_iterations): pidx = idx idx = int(-1) lower2 = float(FLOAT_MAX) @@ -1325,13 +1321,6 @@ def _epa( # clear horizon pt.nhorizon = 0 - if k == epa_iterations - 1: - wp.printf( - "Warning: opt.ccd_iterations, currently set to %d, needs to be" - " increased.\n", - gjk_iterations, - ) - # return from valid face if idx > -1: x1, x2, dist = _epa_witness(pt, geom1, geom2, geomtype1, geomtype2, idx) @@ -2347,7 +2336,7 @@ def ccd( if pt.status: return result.dist, 1, result.x1, result.x2, -1 - dist, x1, x2, idx = _epa(tolerance, gjk_iterations, epa_iterations, pt, geom1, geom2, geomtype1, geomtype2, is_discrete) + dist, x1, x2, idx = _epa(tolerance, epa_iterations, pt, geom1, geom2, geomtype1, geomtype2, is_discrete) if idx == -1: return FLOAT_MAX, 0, wp.vec3(), wp.vec3(), -1 diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_primitive.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_primitive.py index c7a4051439..f1829de404 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_primitive.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_primitive.py @@ -15,9 +15,11 @@ from typing import Tuple +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src.collision_core import CollisionContext -from mujoco.mjx.third_party.mujoco_warp._src.collision_core import contact_params from mujoco.mjx.third_party.mujoco_warp._src.collision_core import Geom +from mujoco.mjx.third_party.mujoco_warp._src.collision_core import contact_params from mujoco.mjx.third_party.mujoco_warp._src.collision_core import geom_collision_pair from mujoco.mjx.third_party.mujoco_warp._src.collision_core import write_contact from mujoco.mjx.third_party.mujoco_warp._src.collision_primitive_core import box_box @@ -34,15 +36,14 @@ from mujoco.mjx.third_party.mujoco_warp._src.collision_primitive_core import sphere_sphere from mujoco.mjx.third_party.mujoco_warp._src.math import make_frame from mujoco.mjx.third_party.mujoco_warp._src.math import upper_trid_index +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL from mujoco.mjx.third_party.mujoco_warp._src.types import Data from mujoco.mjx.third_party.mujoco_warp._src.types import GeomType -from mujoco.mjx.third_party.mujoco_warp._src.types import mat43 -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL from mujoco.mjx.third_party.mujoco_warp._src.types import Model +from mujoco.mjx.third_party.mujoco_warp._src.types import mat43 from mujoco.mjx.third_party.mujoco_warp._src.types import vec5 from mujoco.mjx.third_party.mujoco_warp._src.warp_util import cache_kernel from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope -import warp as wp wp.set_module_options({"enable_backward": False}) @@ -304,6 +305,7 @@ def plane_sphere_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -339,6 +341,7 @@ def plane_sphere_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -374,6 +377,7 @@ def sphere_sphere_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -408,6 +412,7 @@ def sphere_sphere_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -443,6 +448,7 @@ def sphere_capsule_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -480,6 +486,7 @@ def sphere_capsule_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -515,6 +522,7 @@ def capsule_capsule_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -564,6 +572,7 @@ def capsule_capsule_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -599,6 +608,7 @@ def plane_capsule_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -644,6 +654,7 @@ def plane_capsule_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -679,6 +690,7 @@ def plane_ellipsoid_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -713,6 +725,7 @@ def plane_ellipsoid_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -748,6 +761,7 @@ def plane_box_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -784,6 +798,7 @@ def plane_box_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -819,6 +834,7 @@ def plane_convex_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -855,6 +871,7 @@ def plane_convex_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -890,6 +907,7 @@ def sphere_cylinder_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -934,6 +952,7 @@ def sphere_cylinder_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -969,6 +988,7 @@ def plane_cylinder_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -1015,6 +1035,7 @@ def plane_cylinder_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -1050,6 +1071,7 @@ def sphere_box_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -1083,6 +1105,7 @@ def sphere_box_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -1118,6 +1141,7 @@ def capsule_box_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -1166,6 +1190,7 @@ def capsule_box_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -1201,6 +1226,7 @@ def box_box_wrapper( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -1245,6 +1271,7 @@ def box_box_wrapper( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -1327,6 +1354,7 @@ def primitive_narrowphase( contact_solimp_out: wp.array(dtype=vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), @@ -1416,6 +1444,7 @@ def primitive_narrowphase( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -1511,6 +1540,7 @@ def primitive_narrowphase(m: Model, d: Data, ctx: CollisionContext, collision_ta d.contact.solimp, d.contact.dim, d.contact.geom, + d.contact.efc_address, d.contact.worldid, d.contact.type, d.contact.geomcollisionid, diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_primitive_core.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_primitive_core.py index 306a301e58..8a85d2412f 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_primitive_core.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_primitive_core.py @@ -1489,3 +1489,507 @@ def capsule_box( mat23f(pos1[0], pos1[1], pos1[2], pos2[0], pos2[1], pos2[2]), mat23f(normal1[0], normal1[1], normal1[2], normal2[0], normal2[1], normal2[2]), ) + + +@wp.func +def _tri_area_sign(p1: wp.vec2, p2: wp.vec2, p3: wp.vec2) -> float: + """Sign of (signed) area of planar triangle.""" + return wp.sign((p1[0] - p3[0]) * (p2[1] - p3[1]) - (p2[0] - p3[0]) * (p1[1] - p3[1])) + + +@wp.func +def _tri_point_segment(p: wp.vec2, u: wp.vec2, v: wp.vec2) -> wp.vec2: + """Find nearest point to p within line segment (u, v).""" + uv = v - u + up = p - u + + denom = wp.max(MJ_MINVAL, wp.dot(uv, uv)) + a = wp.dot(uv, up) / denom + + if a <= 0.0: + return u + elif a >= 1.0: + return v + else: + return u + a * uv + + +@wp.func +def sphere_triangle( + sphere_pos: wp.vec3, + sphere_radius: float, + t1: wp.vec3, + t2: wp.vec3, + t3: wp.vec3, + tri_radius: float, +) -> Tuple[float, wp.vec3, wp.vec3]: + """Core contact geometry calculation for sphere-triangle collision. + + Port of mjraw_SphereTriangle from engine_collision_primitive.c + + Args: + sphere_pos: Center position of the sphere. + sphere_radius: Radius of the sphere. + t1: Triangle vertex positions. + t2: Triangle vertex positions. + t3: Triangle vertex positions. + tri_radius: Triangle (flex element) radius. + + Returns: + - Contact distance (MJ_MAXVAL if no collision). + - Contact position. + - Contact normal vector. + """ + S = sphere_pos - t1 + A = t2 - t1 + B = t3 - t1 + + N = wp.normalize(wp.cross(A, B)) + + dstS = wp.dot(N, S) + + P = S - dstS * N + + V1 = wp.normalize(A) + lenA = wp.length(A) + V2 = wp.normalize(wp.cross(N, A)) + + o = wp.vec2(0.0, 0.0) + a = wp.vec2(lenA, 0.0) + b = wp.vec2(wp.dot(V1, B), wp.dot(V2, B)) + p = wp.vec2(wp.dot(V1, P), wp.dot(V2, P)) + + sign1 = _tri_area_sign(p, o, a) + sign2 = _tri_area_sign(p, a, b) + sign3 = _tri_area_sign(p, b, o) + + X = wp.vec3(0.0) + if sign1 == sign2 and sign2 == sign3: + X = P + else: + x0 = _tri_point_segment(p, o, a) + x1 = _tri_point_segment(p, a, b) + x2 = _tri_point_segment(p, b, o) + + d0 = wp.length(p - x0) + d1 = wp.length(p - x1) + d2 = wp.length(p - x2) + + if d0 < d1 and d0 < d2: + X = x0[0] * V1 + x0[1] * V2 + elif d1 < d2: + X = x1[0] * V1 + x1[1] * V2 + else: + X = x2[0] * V1 + x2[1] * V2 + + nrm = X - S + dst = wp.length(nrm) + + if dst > MJ_MINVAL: + nrm = nrm / dst + else: + nrm = N + + dist = dst - sphere_radius - tri_radius + pos = sphere_pos + nrm * (sphere_radius + 0.5 * dist) + + return dist, pos, nrm + + +@wp.func +def box_triangle( + box_pos: wp.vec3, + box_rot: wp.mat33, + box_size: wp.vec3, + t1: wp.vec3, + t2: wp.vec3, + t3: wp.vec3, + tri_radius: float, +) -> Tuple[wp.vec2, mat23f, mat23f]: + """Core contact geometry calculation for box-triangle collision. + + Port of mjraw_BoxTriangle from engine_collision_primitive.c + + Args: + box_pos: Center position of the box. + box_rot: Orientation matrix of the box. + box_size: Half-sizes of the box. + t1: Triangle vertex positions. + t2: Triangle vertex positions. + t3: Triangle vertex positions. + tri_radius: Triangle (flex element) radius. + + Returns: + - wp.vec2 of distances for up to 2 contacts (MJ_MAXVAL if no collision). + - mat23f of contact positions (2 x vec3). + - mat23f of contact normals (2 x vec3). + """ + dist1 = MJ_MAXVAL + dist2 = MJ_MAXVAL + pos1 = wp.vec3(0.0) + pos2 = wp.vec3(0.0) + nrm1 = wp.vec3(0.0) + nrm2 = wp.vec3(0.0) + cnt = 0 + + box_rotT = wp.transpose(box_rot) + + for vi in range(3): + vert = wp.vec3(0.0) + if vi == 0: + vert = t1 + elif vi == 1: + vert = t2 + else: + vert = t3 + + diff = vert - box_pos + local = box_rotT @ diff + + maxaxis = 0 + maxval = wp.abs(local[0]) - box_size[0] + for j in range(1, 3): + val = wp.abs(local[j]) - box_size[j] + if val > maxval: + maxval = val + maxaxis = j + + inside = True + for j in range(3): + if wp.abs(local[j]) > box_size[j] + tri_radius: + inside = False + + if inside and cnt < 2: + nrm_local = wp.vec3(0.0) + if maxaxis == 0: + nrm_local = wp.vec3(wp.sign(local[0]), 0.0, 0.0) + elif maxaxis == 1: + nrm_local = wp.vec3(0.0, wp.sign(local[1]), 0.0) + else: + nrm_local = wp.vec3(0.0, 0.0, wp.sign(local[2])) + + nrm_global = box_rot @ nrm_local + d = maxval - tri_radius + offset = tri_radius + d * 0.5 + p = vert - nrm_global * offset + + if cnt == 0: + dist1 = d + pos1 = p + nrm1 = nrm_global + else: + dist2 = d + pos2 = p + nrm2 = nrm_global + cnt += 1 + + for i in range(8): + if cnt >= 2: + break + + vec = wp.vec3( + wp.where(i & 1, box_size[0], -box_size[0]), + wp.where(i & 2, box_size[1], -box_size[1]), + wp.where(i & 4, box_size[2], -box_size[2]), + ) + corner = box_rot @ vec + box_pos + + d, p, n = sphere_triangle(corner, 0.0, t1, t2, t3, tri_radius) + if d < MJ_MAXVAL: + if cnt == 0: + dist1 = d + pos1 = p + nrm1 = n + elif cnt == 1: + dist2 = d + pos2 = p + nrm2 = n + cnt += 1 + + return ( + wp.vec2(dist1, dist2), + mat23f(pos1[0], pos1[1], pos1[2], pos2[0], pos2[1], pos2[2]), + mat23f(nrm1[0], nrm1[1], nrm1[2], nrm2[0], nrm2[1], nrm2[2]), + ) + + +@wp.func +def capsule_triangle( + capsule_pos: wp.vec3, + capsule_axis: wp.vec3, + capsule_radius: float, + capsule_half_length: float, + t1: wp.vec3, + t2: wp.vec3, + t3: wp.vec3, + tri_radius: float, +) -> Tuple[wp.vec2, mat23f, mat23f]: + """Core contact geometry calculation for capsule-triangle collision. + + Port of mjraw_CapsuleTriangle from engine_collision_primitive.c + + Args: + capsule_pos: Center position of the capsule. + capsule_axis: Unit axis direction of the capsule. + capsule_radius: Radius of the capsule. + capsule_half_length: Half-length of the capsule cylinder. + t1: Triangle vertex positions. + t2: Triangle vertex positions. + t3: Triangle vertex positions. + tri_radius: Triangle (flex element) radius. + + Returns: + - wp.vec2 of distances for up to 2 contacts (MJ_MAXVAL if no collision). + - mat23f of contact positions (2 x vec3). + - mat23f of contact normals (2 x vec3). + """ + dist1 = MJ_MAXVAL + dist2 = MJ_MAXVAL + pos1 = wp.vec3(0.0) + pos2 = wp.vec3(0.0) + nrm1 = wp.vec3(0.0) + nrm2 = wp.vec3(0.0) + cnt = 0 + + p1 = capsule_pos - capsule_axis * capsule_half_length + p2 = capsule_pos + capsule_axis * capsule_half_length + + d, p, n = sphere_triangle(p1, capsule_radius, t1, t2, t3, tri_radius) + if d < MJ_MAXVAL: + dist1 = d + pos1 = p + nrm1 = n + cnt = 1 + + d, p, n = sphere_triangle(p2, capsule_radius, t1, t2, t3, tri_radius) + if d < MJ_MAXVAL and cnt < 2: + if cnt == 0: + dist1 = d + pos1 = p + nrm1 = n + else: + dist2 = d + pos2 = p + nrm2 = n + cnt += 1 + + ab = p2 - p1 + ab_len_sq = 4.0 * capsule_half_length * capsule_half_length + + for vi in range(3): + if cnt >= 2: + break + + vert = wp.vec3(0.0) + if vi == 0: + vert = t1 + elif vi == 1: + vert = t2 + else: + vert = t3 + + vec = vert - p1 + t_param = wp.dot(vec, ab) / wp.max(MJ_MINVAL, ab_len_sq) + + if t_param > MJ_MINVAL and t_param < 1.0 - MJ_MINVAL: + closest = p1 + ab * t_param + diff = vert - closest + dist_raw = wp.length(diff) + + if dist_raw > MJ_MINVAL: + nrm = diff / dist_raw + d = dist_raw - capsule_radius - tri_radius + p = (closest + vert + nrm * (capsule_radius - tri_radius)) * 0.5 + + if cnt == 0: + dist1 = d + pos1 = p + nrm1 = nrm + else: + dist2 = d + pos2 = p + nrm2 = nrm + cnt += 1 + + return ( + wp.vec2(dist1, dist2), + mat23f(pos1[0], pos1[1], pos1[2], pos2[0], pos2[1], pos2[2]), + mat23f(nrm1[0], nrm1[1], nrm1[2], nrm2[0], nrm2[1], nrm2[2]), + ) + + +@wp.func +def cylinder_triangle( + cylinder_pos: wp.vec3, + cylinder_axis: wp.vec3, + cylinder_radius: float, + cylinder_half_height: float, + t1: wp.vec3, + t2: wp.vec3, + t3: wp.vec3, + tri_radius: float, +) -> Tuple[wp.vec2, mat23f, mat23f]: + """Core contact geometry calculation for cylinder-triangle collision. + + Args: + cylinder_pos: Center position of the cylinder. + cylinder_axis: Unit axis direction of the cylinder. + cylinder_radius: Radius of the cylinder. + cylinder_half_height: Half-height of the cylinder. + t1: Triangle vertex positions. + t2: Triangle vertex positions. + t3: Triangle vertex positions. + tri_radius: Triangle (flex element) radius. + + Returns: + - wp.vec2 of distances for up to 2 contacts (MJ_MAXVAL if no collision). + - mat23f of contact positions (2 x vec3). + - mat23f of contact normals (2 x vec3). + """ + dist1 = MJ_MAXVAL + dist2 = MJ_MAXVAL + pos1 = wp.vec3(0.0) + pos2 = wp.vec3(0.0) + nrm1 = wp.vec3(0.0) + nrm2 = wp.vec3(0.0) + cnt = int(0) + + p1 = cylinder_pos - cylinder_axis * cylinder_half_height + p2 = cylinder_pos + cylinder_axis * cylinder_half_height + + ab = p2 - p1 + ab_len_sq = 4.0 * cylinder_half_height * cylinder_half_height + + for vi in range(3): + if cnt >= 2: + break + + vert = wp.vec3(0.0) + if vi == 0: + vert = t1 + elif vi == 1: + vert = t2 + else: + vert = t3 + + vec = vert - p1 + t_param = wp.dot(vec, ab) / wp.max(MJ_MINVAL, ab_len_sq) + + if t_param > MJ_MINVAL and t_param < 1.0 - MJ_MINVAL: + closest = p1 + ab * t_param + diff = vert - closest + dist_raw = wp.length(diff) + + if dist_raw < cylinder_radius + tri_radius: + if dist_raw > MJ_MINVAL: + nrm = diff / dist_raw + d = dist_raw - cylinder_radius - tri_radius + p = (closest + vert + nrm * (cylinder_radius - tri_radius)) * 0.5 + else: + dist_to_side = cylinder_radius + dist_to_p2 = (1.0 - t_param) * wp.sqrt(ab_len_sq) + dist_to_p1 = t_param * wp.sqrt(ab_len_sq) + + if dist_to_p2 < dist_to_side and dist_to_p2 < dist_to_p1: + nrm = cylinder_axis + d = -dist_to_p2 - tri_radius + p = vert + elif dist_to_p1 < dist_to_side: + nrm = -cylinder_axis + d = -dist_to_p1 - tri_radius + p = vert + else: + tri_normal = wp.normalize(wp.cross(t2 - t1, t3 - t1)) + nrm = tri_normal + d = -cylinder_radius - tri_radius + p = closest + + if cnt == 0: + dist1 = d + pos1 = p + nrm1 = nrm + else: + dist2 = d + pos2 = p + nrm2 = nrm + cnt += 1 + elif t_param <= MJ_MINVAL: + diff = vert - p1 + signed_dist = wp.dot(diff, cylinder_axis) + perp = diff - cylinder_axis * signed_dist + perp_len = wp.length(perp) + + if perp_len < cylinder_radius: + d = -signed_dist - tri_radius + nrm = -cylinder_axis + p = vert - nrm * (tri_radius + d * 0.5) + if cnt == 0: + dist1 = d + pos1 = p + nrm1 = nrm + else: + dist2 = d + pos2 = p + nrm2 = nrm + cnt += 1 + elif perp_len < cylinder_radius + tri_radius: + edge_dir = perp / perp_len + edge_point = p1 + edge_dir * cylinder_radius + diff_to_edge = vert - edge_point + dist_raw = wp.length(diff_to_edge) + if dist_raw > MJ_MINVAL: + nrm = diff_to_edge / dist_raw + d = dist_raw - tri_radius + p = vert - nrm * (tri_radius + d * 0.5) + if cnt == 0: + dist1 = d + pos1 = p + nrm1 = nrm + else: + dist2 = d + pos2 = p + nrm2 = nrm + cnt += 1 + else: + diff = vert - p2 + signed_dist = wp.dot(diff, cylinder_axis) + perp = diff - cylinder_axis * signed_dist + perp_len = wp.length(perp) + + if perp_len < cylinder_radius: + d = signed_dist - tri_radius + nrm = cylinder_axis + p = vert - nrm * (tri_radius + d * 0.5) + if cnt == 0: + dist1 = d + pos1 = p + nrm1 = nrm + else: + dist2 = d + pos2 = p + nrm2 = nrm + cnt += 1 + elif perp_len < cylinder_radius + tri_radius: + edge_dir = perp / perp_len + edge_point = p2 + edge_dir * cylinder_radius + diff_to_edge = vert - edge_point + dist_raw = wp.length(diff_to_edge) + if dist_raw > MJ_MINVAL: + nrm = diff_to_edge / dist_raw + d = dist_raw - tri_radius + p = vert - nrm * (tri_radius + d * 0.5) + if cnt == 0: + dist1 = d + pos1 = p + nrm1 = nrm + else: + dist2 = d + pos2 = p + nrm2 = nrm + cnt += 1 + + return ( + wp.vec2(dist1, dist2), + mat23f(pos1[0], pos1[1], pos1[2], pos2[0], pos2[1], pos2[2]), + mat23f(nrm1[0], nrm1[1], nrm1[2], nrm2[0], nrm2[1], nrm2[2]), + ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_sdf.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_sdf.py index c84b67c820..4decf33c99 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_sdf.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_sdf.py @@ -15,6 +15,8 @@ from typing import Tuple +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src.collision_core import CollisionContext from mujoco.mjx.third_party.mujoco_warp._src.collision_core import contact_params from mujoco.mjx.third_party.mujoco_warp._src.collision_core import geom_collision_pair @@ -27,9 +29,9 @@ from mujoco.mjx.third_party.mujoco_warp._src.types import vec5 from mujoco.mjx.third_party.mujoco_warp._src.types import vec8 from mujoco.mjx.third_party.mujoco_warp._src.types import vec8i +from mujoco.mjx.third_party.mujoco_warp._src.types import vec_pluginattr from mujoco.mjx.third_party.mujoco_warp._src.util_misc import halton from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope -import warp as wp wp.set_module_options({"enable_backward": False}) @@ -38,8 +40,8 @@ class OptimizationParams: rel_mat: wp.mat33 rel_pos: wp.vec3 - attr1: wp.vec3 - attr2: wp.vec3 + attr1: vec_pluginattr + attr2: vec_pluginattr @wp.struct @@ -77,20 +79,24 @@ class MeshData: @wp.func def get_sdf_params( - # Model: - oct_child: wp.array(dtype=vec8i), - oct_aabb: wp.array2d(dtype=wp.vec3), - oct_coeff: wp.array(dtype=vec8), - mesh_octadr: wp.array(dtype=int), - plugin: wp.array(dtype=int), - plugin_attr: wp.array(dtype=wp.vec3f), - # In: - g_type: int, - g_size: wp.vec3, - plugin_id: int, - mesh_id: int, -) -> Tuple[wp.vec3, int, VolumeData, MeshData]: - attributes = g_size + # Model: + oct_child: wp.array(dtype=vec8i), + oct_aabb: wp.array2d(dtype=wp.vec3), + oct_coeff: wp.array(dtype=vec8), + mesh_octadr: wp.array(dtype=int), + plugin: wp.array(dtype=int), + plugin_attr: wp.array(dtype=vec_pluginattr), + # In: + g_type: int, + g_size: wp.vec3, + plugin_id: int, + mesh_id: int, +) -> Tuple[vec_pluginattr, int, VolumeData, MeshData]: + # default attributes from geom size, first 3 values copied + attributes = vec_pluginattr() + attributes[0] = g_size[0] + attributes[1] = g_size[1] + attributes[2] = g_size[2] plugin_index = -1 volume_data = VolumeData() @@ -108,6 +114,16 @@ def get_sdf_params( volume_data.oct_coeff = oct_coeff volume_data.valid = True + elif g_type == GeomType.MESH and mesh_id != -1 and mesh_octadr[mesh_id] != -1: + octadr = mesh_octadr[mesh_id] + volume_data.center = oct_aabb[octadr, 0] + volume_data.half_size = oct_aabb[octadr, 1] + volume_data.root = octadr + volume_data.oct_aabb = oct_aabb + volume_data.oct_child = oct_child + volume_data.oct_coeff = oct_coeff + volume_data.valid = True + return attributes, plugin_index, volume_data, MeshData() @@ -215,24 +231,28 @@ def grad_ellipsoid(p: wp.vec3, size: wp.vec3) -> wp.vec3: @wp.func -def user_sdf(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> float: +def user_sdf(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> float: + """User-defined SDF function. + + Access attributes via attr[i] where i is the attribute index (0 to _NPLUGINATTR-1). + """ wp.printf("ERROR: user_sdf function must be implemented by user code\n") return 0.0 @wp.func -def user_sdf_grad(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> wp.vec3: +def user_sdf_grad(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> wp.vec3: + """User-defined SDF gradient function. + + Access attributes via attr[i] where i is the attribute index (0 to _NPLUGINATTR-1). + """ wp.printf("ERROR: user_sdf_grad function must be implemented by user code\n") return wp.vec3(0.0) @wp.func def find_oct( - oct_child: wp.array(dtype=vec8i), - oct_aabb: wp.array2d(dtype=wp.vec3), - p: wp.vec3, - grad: bool, - root: int, + oct_child: wp.array(dtype=vec8i), oct_aabb: wp.array2d(dtype=wp.vec3), p: wp.vec3, grad: bool, root: int ) -> Tuple[int, Tuple[vec8, vec8, vec8]]: stack = root niter = int(100) @@ -268,14 +288,14 @@ def find_oct( # child indices are relative to root (mesh_octadr offset) child0 = oct_child[node][0] if ( - child0 == -1 - and oct_child[node][1] == -1 - and oct_child[node][2] == -1 - and oct_child[node][3] == -1 - and oct_child[node][4] == -1 - and oct_child[node][5] == -1 - and oct_child[node][6] == -1 - and oct_child[node][7] == -1 + child0 == -1 + and oct_child[node][1] == -1 + and oct_child[node][2] == -1 + and oct_child[node][3] == -1 + and oct_child[node][4] == -1 + and oct_child[node][5] == -1 + and oct_child[node][6] == -1 + and oct_child[node][7] == -1 ): for j in range(8): if not grad: @@ -342,13 +362,7 @@ def box_project(center: wp.vec3, half_size: wp.vec3, xyz: wp.vec3) -> Tuple[floa @wp.func def sample_volume_sdf(xyz: wp.vec3, volume_data: VolumeData) -> float: dist0, point = box_project(volume_data.center, volume_data.half_size, xyz) - node, weights = find_oct( - volume_data.oct_child, - volume_data.oct_aabb, - point, - grad=False, - root=volume_data.root, - ) + node, weights = find_oct(volume_data.oct_child, volume_data.oct_aabb, point, grad=False, root=volume_data.root) return dist0 + wp.dot(weights[0], volume_data.oct_coeff[node]) @@ -365,13 +379,7 @@ def sample_volume_grad(xyz: wp.vec3, volume_data: VolumeData) -> wp.vec3: grad_y = (sample_volume_sdf(xyz + dy, volume_data) - f) / h grad_z = (sample_volume_sdf(xyz + dz, volume_data) - f) / h return wp.vec3(grad_x, grad_y, grad_z) - node, weights = find_oct( - volume_data.oct_child, - volume_data.oct_aabb, - point, - grad=True, - root=volume_data.root, - ) + node, weights = find_oct(volume_data.oct_child, volume_data.oct_aabb, point, grad=True, root=volume_data.root) grad_x = wp.dot(weights[0], volume_data.oct_coeff[node]) grad_y = wp.dot(weights[1], volume_data.oct_coeff[node]) grad_z = wp.dot(weights[2], volume_data.oct_coeff[node]) @@ -379,15 +387,17 @@ def sample_volume_grad(xyz: wp.vec3, volume_data: VolumeData) -> wp.vec3: @wp.func -def sdf(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData) -> float: +def sdf(type: int, p: wp.vec3, attr: vec_pluginattr, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData) -> float: + # extract first 3 elements as vec3 for primitive sdf functions + attr_vec3 = wp.vec3(attr[0], attr[1], attr[2]) if type == GeomType.PLANE: return p[2] elif type == GeomType.SPHERE: - return sphere(p, attr) + return sphere(p, attr_vec3) elif type == GeomType.BOX: - return box(p, attr) + return box(p, attr_vec3) elif type == GeomType.ELLIPSOID: - return ellipsoid(p, attr) + return ellipsoid(p, attr_vec3) elif type == GeomType.MESH and mesh_data.valid: mesh_data.pnt = p mesh_data.vec = -wp.normalize(p) @@ -425,21 +435,27 @@ def sdf(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: Volume return sample_volume_sdf(p, volume_data) else: return user_sdf(p, attr, sdf_type) + elif type == GeomType.MESH and volume_data.valid: + return sample_volume_sdf(p, volume_data) wp.printf("ERROR: SDF type not implemented\n") return 0.0 @wp.func -def sdf_grad(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData) -> wp.vec3: +def sdf_grad( + type: int, p: wp.vec3, attr: vec_pluginattr, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData +) -> wp.vec3: + # extract first 3 elements as vec3 for primitive sdf functions + attr_vec3 = wp.vec3(attr[0], attr[1], attr[2]) if type == GeomType.PLANE: grad = wp.vec3(0.0, 0.0, 1.0) return grad elif type == GeomType.SPHERE: return grad_sphere(p) elif type == GeomType.BOX: - return grad_box(p, attr) + return grad_box(p, attr_vec3) elif type == GeomType.ELLIPSOID: - return grad_ellipsoid(p, attr) + return grad_ellipsoid(p, attr_vec3) elif type == GeomType.MESH and mesh_data.valid: mesh_data.pnt = p mesh_data.vec = -wp.normalize(p) @@ -466,6 +482,8 @@ def sdf_grad(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: V return sample_volume_grad(p, volume_data) else: return user_sdf_grad(p, attr, sdf_type) + elif type == GeomType.MESH and volume_data.valid: + return sample_volume_grad(p, volume_data) wp.printf("ERROR: SDF grad type not implemented\n") return wp.vec3(0.0) @@ -476,8 +494,8 @@ def clearance( type1: int, p1: wp.vec3, p2: wp.vec3, - s1: wp.vec3, - s2: wp.vec3, + s1: vec_pluginattr, + s2: vec_pluginattr, sdf_type1: int, sdf_type2: int, sfd_intersection: bool, @@ -606,8 +624,8 @@ def gradient_descent( # In: type1: int, x0_initial: wp.vec3, - attr1: wp.vec3, - attr2: wp.vec3, + attr1: vec_pluginattr, + attr2: vec_pluginattr, pos1: wp.vec3, rot1: wp.mat33, pos2: wp.vec3, @@ -645,76 +663,77 @@ def gradient_descent( @wp.kernel def _sdf_narrowphase( - # Model: - nmeshface: int, - oct_child: wp.array(dtype=vec8i), - oct_aabb: wp.array2d(dtype=wp.vec3), - oct_coeff: wp.array(dtype=vec8), - geom_type: wp.array(dtype=int), - geom_condim: wp.array(dtype=int), - geom_dataid: wp.array(dtype=int), - geom_priority: wp.array(dtype=int), - geom_solmix: wp.array2d(dtype=float), - geom_solref: wp.array2d(dtype=wp.vec2), - geom_solimp: wp.array2d(dtype=vec5), - geom_size: wp.array2d(dtype=wp.vec3), - geom_aabb: wp.array3d(dtype=wp.vec3), - geom_friction: wp.array2d(dtype=wp.vec3), - geom_margin: wp.array2d(dtype=float), - geom_gap: wp.array2d(dtype=float), - mesh_vertadr: wp.array(dtype=int), - mesh_vertnum: wp.array(dtype=int), - mesh_faceadr: wp.array(dtype=int), - mesh_octadr: wp.array(dtype=int), - mesh_graphadr: wp.array(dtype=int), - mesh_vert: wp.array(dtype=wp.vec3), - mesh_face: wp.array(dtype=wp.vec3i), - mesh_graph: wp.array(dtype=int), - mesh_polynum: wp.array(dtype=int), - mesh_polyadr: wp.array(dtype=int), - mesh_polynormal: wp.array(dtype=wp.vec3), - mesh_polyvertadr: wp.array(dtype=int), - mesh_polyvertnum: wp.array(dtype=int), - mesh_polyvert: wp.array(dtype=int), - mesh_polymapadr: wp.array(dtype=int), - mesh_polymapnum: wp.array(dtype=int), - mesh_polymap: wp.array(dtype=int), - pair_dim: wp.array(dtype=int), - pair_solref: wp.array2d(dtype=wp.vec2), - pair_solreffriction: wp.array2d(dtype=wp.vec2), - pair_solimp: wp.array2d(dtype=vec5), - pair_margin: wp.array2d(dtype=float), - pair_gap: wp.array2d(dtype=float), - pair_friction: wp.array2d(dtype=vec5), - plugin: wp.array(dtype=int), - plugin_attr: wp.array(dtype=wp.vec3f), - geom_plugin_index: wp.array(dtype=int), - # Data in: - geom_xpos_in: wp.array2d(dtype=wp.vec3), - geom_xmat_in: wp.array2d(dtype=wp.mat33), - naconmax_in: int, - ncollision_in: wp.array(dtype=int), - # In: - collision_pair_in: wp.array(dtype=wp.vec2i), - collision_pairid_in: wp.array(dtype=wp.vec2i), - collision_worldid_in: wp.array(dtype=int), - sdf_initpoints: int, - sdf_iterations: int, - # Data out: - contact_dist_out: wp.array(dtype=float), - contact_pos_out: wp.array(dtype=wp.vec3), - contact_frame_out: wp.array(dtype=wp.mat33), - contact_includemargin_out: wp.array(dtype=float), - contact_friction_out: wp.array(dtype=vec5), - contact_solref_out: wp.array(dtype=wp.vec2), - contact_solreffriction_out: wp.array(dtype=wp.vec2), - contact_solimp_out: wp.array(dtype=vec5), - contact_dim_out: wp.array(dtype=int), - contact_geom_out: wp.array(dtype=wp.vec2i), - contact_worldid_out: wp.array(dtype=int), - contact_type_out: wp.array(dtype=int), - contact_geomcollisionid_out: wp.array(dtype=int), - nacon_out: wp.array(dtype=int), + # Model: + nmeshface: int, + oct_child: wp.array(dtype=vec8i), + oct_aabb: wp.array2d(dtype=wp.vec3), + oct_coeff: wp.array(dtype=vec8), + geom_type: wp.array(dtype=int), + geom_condim: wp.array(dtype=int), + geom_dataid: wp.array(dtype=int), + geom_priority: wp.array(dtype=int), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_size: wp.array2d(dtype=wp.vec3), + geom_aabb: wp.array3d(dtype=wp.vec3), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), + mesh_vertadr: wp.array(dtype=int), + mesh_vertnum: wp.array(dtype=int), + mesh_faceadr: wp.array(dtype=int), + mesh_octadr: wp.array(dtype=int), + mesh_graphadr: wp.array(dtype=int), + mesh_vert: wp.array(dtype=wp.vec3), + mesh_face: wp.array(dtype=wp.vec3i), + mesh_graph: wp.array(dtype=int), + mesh_polynum: wp.array(dtype=int), + mesh_polyadr: wp.array(dtype=int), + mesh_polynormal: wp.array(dtype=wp.vec3), + mesh_polyvertadr: wp.array(dtype=int), + mesh_polyvertnum: wp.array(dtype=int), + mesh_polyvert: wp.array(dtype=int), + mesh_polymapadr: wp.array(dtype=int), + mesh_polymapnum: wp.array(dtype=int), + mesh_polymap: wp.array(dtype=int), + pair_dim: wp.array(dtype=int), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), + plugin: wp.array(dtype=int), + plugin_attr: wp.array(dtype=vec_pluginattr), + geom_plugin_index: wp.array(dtype=int), + # Data in: + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + naconmax_in: int, + ncollision_in: wp.array(dtype=int), + # In: + collision_pair_in: wp.array(dtype=wp.vec2i), + collision_pairid_in: wp.array(dtype=wp.vec2i), + collision_worldid_in: wp.array(dtype=int), + sdf_initpoints: int, + sdf_iterations: int, + # Data out: + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_dim_out: wp.array(dtype=int), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_efc_address_out: wp.array2d(dtype=int), + contact_worldid_out: wp.array(dtype=int), + contact_type_out: wp.array(dtype=int), + contact_geomcollisionid_out: wp.array(dtype=int), + nacon_out: wp.array(dtype=int), ): i, contact_tid = wp.tid() if i >= sdf_initpoints: @@ -799,29 +818,11 @@ def _sdf_narrowphase( rot1 = geom1.rot attr1, g1_plugin_id, volume_data1, mesh_data1 = get_sdf_params( - oct_child, - oct_aabb, - oct_coeff, - mesh_octadr, - plugin, - plugin_attr, - type1, - geom1.size, - g1_plugin, - geom_dataid[g1], + oct_child, oct_aabb, oct_coeff, mesh_octadr, plugin, plugin_attr, type1, geom1.size, g1_plugin, geom_dataid[g1] ) attr2, g2_plugin_id, volume_data2, mesh_data2 = get_sdf_params( - oct_child, - oct_aabb, - oct_coeff, - mesh_octadr, - plugin, - plugin_attr, - type2, - geom2.size, - g2_plugin, - geom_dataid[g2], + oct_child, oct_aabb, oct_coeff, mesh_octadr, plugin, plugin_attr, type2, geom2.size, g2_plugin, geom_dataid[g2] ) mesh_data1.nmeshface = nmeshface @@ -900,6 +901,7 @@ def _sdf_narrowphase( contact_solimp_out, contact_dim_out, contact_geom_out, + contact_efc_address_out, contact_worldid_out, contact_type_out, contact_geomcollisionid_out, @@ -910,76 +912,77 @@ def _sdf_narrowphase( @event_scope def sdf_narrowphase(m: Model, d: Data, ctx: CollisionContext): wp.launch( - _sdf_narrowphase, - dim=(m.opt.sdf_initpoints, d.naconmax), - inputs=[ - m.nmeshface, - m.oct_child, - m.oct_aabb, - m.oct_coeff, - m.geom_type, - m.geom_condim, - m.geom_dataid, - m.geom_priority, - m.geom_solmix, - m.geom_solref, - m.geom_solimp, - m.geom_size, - m.geom_aabb, - m.geom_friction, - m.geom_margin, - m.geom_gap, - m.mesh_vertadr, - m.mesh_vertnum, - m.mesh_faceadr, - m.mesh_octadr, - m.mesh_graphadr, - m.mesh_vert, - m.mesh_face, - m.mesh_graph, - m.mesh_polynum, - m.mesh_polyadr, - m.mesh_polynormal, - m.mesh_polyvertadr, - m.mesh_polyvertnum, - m.mesh_polyvert, - m.mesh_polymapadr, - m.mesh_polymapnum, - m.mesh_polymap, - m.pair_dim, - m.pair_solref, - m.pair_solreffriction, - m.pair_solimp, - m.pair_margin, - m.pair_gap, - m.pair_friction, - m.plugin, - m.plugin_attr, - m.geom_plugin_index, - d.geom_xpos, - d.geom_xmat, - d.naconmax, - d.ncollision, - ctx.collision_pair, - ctx.collision_pairid, - ctx.collision_worldid, - m.opt.sdf_initpoints, - m.opt.sdf_iterations, - ], - outputs=[ - d.contact.dist, - d.contact.pos, - d.contact.frame, - d.contact.includemargin, - d.contact.friction, - d.contact.solref, - d.contact.solreffriction, - d.contact.solimp, - d.contact.dim, - d.contact.geom, - d.contact.worldid, - d.contact.type, - d.contact.geomcollisionid, - d.nacon, - ], + _sdf_narrowphase, + dim=(m.opt.sdf_initpoints, d.naconmax), + inputs=[ + m.nmeshface, + m.oct_child, + m.oct_aabb, + m.oct_coeff, + m.geom_type, + m.geom_condim, + m.geom_dataid, + m.geom_priority, + m.geom_solmix, + m.geom_solref, + m.geom_solimp, + m.geom_size, + m.geom_aabb, + m.geom_friction, + m.geom_margin, + m.geom_gap, + m.mesh_vertadr, + m.mesh_vertnum, + m.mesh_faceadr, + m.mesh_octadr, + m.mesh_graphadr, + m.mesh_vert, + m.mesh_face, + m.mesh_graph, + m.mesh_polynum, + m.mesh_polyadr, + m.mesh_polynormal, + m.mesh_polyvertadr, + m.mesh_polyvertnum, + m.mesh_polyvert, + m.mesh_polymapadr, + m.mesh_polymapnum, + m.mesh_polymap, + m.pair_dim, + m.pair_solref, + m.pair_solreffriction, + m.pair_solimp, + m.pair_margin, + m.pair_gap, + m.pair_friction, + m.plugin, + m.plugin_attr, + m.geom_plugin_index, + d.geom_xpos, + d.geom_xmat, + d.naconmax, + d.ncollision, + ctx.collision_pair, + ctx.collision_pairid, + ctx.collision_worldid, + m.opt.sdf_initpoints, + m.opt.sdf_iterations, + ], + outputs=[ + d.contact.dist, + d.contact.pos, + d.contact.frame, + d.contact.includemargin, + d.contact.friction, + d.contact.solref, + d.contact.solreffriction, + d.contact.solimp, + d.contact.dim, + d.contact.geom, + d.contact.efc_address, + d.contact.worldid, + d.contact.type, + d.contact.geomcollisionid, + d.nacon, + ], ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/constraint.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/constraint.py index 228006e834..eec47583ad 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/constraint.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/constraint.py @@ -13,18 +13,18 @@ # limitations under the License. # ============================================================================== +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src import math from mujoco.mjx.third_party.mujoco_warp._src import support from mujoco.mjx.third_party.mujoco_warp._src import types from mujoco.mjx.third_party.mujoco_warp._src.types import ConstraintType from mujoco.mjx.third_party.mujoco_warp._src.types import ContactType from mujoco.mjx.third_party.mujoco_warp._src.types import DisableBit -from mujoco.mjx.third_party.mujoco_warp._src.types import SPARSE_CONSTRAINT_JACOBIAN -from mujoco.mjx.third_party.mujoco_warp._src.types import vec11 from mujoco.mjx.third_party.mujoco_warp._src.types import vec5 +from mujoco.mjx.third_party.mujoco_warp._src.types import vec11 from mujoco.mjx.third_party.mujoco_warp._src.warp_util import cache_kernel from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope -import warp as wp wp.set_module_options({"enable_backward": False}) @@ -36,6 +36,8 @@ def _zero_constraint_counts( nf_out: wp.array(dtype=int), nl_out: wp.array(dtype=int), nefc_out: wp.array(dtype=int), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid = wp.tid() @@ -44,35 +46,36 @@ def _zero_constraint_counts( nf_out[worldid] = 0 nl_out[worldid] = 0 nefc_out[worldid] = 0 + efc_nnz_out[worldid] = 0 @wp.func def _efc_row( - # Model: - opt_disableflags: int, - # In: - worldid: int, - timestep: float, - efcid: int, - pos_aref: float, - pos_imp: float, - invweight: float, - solref: wp.vec2, - solimp: vec5, - margin: float, - vel: float, - frictionloss: float, - type: int, - id: int, - # Out: - type_out: wp.array2d(dtype=int), - id_out: wp.array2d(dtype=int), - pos_out: wp.array2d(dtype=float), - margin_out: wp.array2d(dtype=float), - D_out: wp.array2d(dtype=float), - vel_out: wp.array2d(dtype=float), - aref_out: wp.array2d(dtype=float), - frictionloss_out: wp.array2d(dtype=float), + # Model: + opt_disableflags: int, + # In: + worldid: int, + timestep: float, + efcid: int, + pos_aref: float, + pos_imp: float, + invweight: float, + solref: wp.vec2, + solimp: vec5, + margin: float, + vel: float, + frictionloss: float, + type: int, + id: int, + # Out: + type_out: wp.array2d(dtype=int), + id_out: wp.array2d(dtype=int), + pos_out: wp.array2d(dtype=float), + margin_out: wp.array2d(dtype=float), + D_out: wp.array2d(dtype=float), + vel_out: wp.array2d(dtype=float), + aref_out: wp.array2d(dtype=float), + frictionloss_out: wp.array2d(dtype=float), ): # calculate kbi timeconst = solref[0] @@ -108,9 +111,7 @@ def _efc_row( imp = wp.where(imp_x > 1.0, dmax, imp) # set outputs - D_out[worldid, efcid] = 1.0 / wp.max( - invweight * (1.0 - imp) / imp, types.MJ_MINVAL - ) + D_out[worldid, efcid] = 1.0 / wp.max(invweight * (1.0 - imp) / imp, types.MJ_MINVAL) vel_out[worldid, efcid] = vel aref_out[worldid, efcid] = -k * imp * pos_aref - b * vel pos_out[worldid, efcid] = pos_aref + margin @@ -122,52 +123,55 @@ def _efc_row( @wp.kernel def _equality_connect( - # Model: - nv: int, - nsite: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - body_parentid: wp.array(dtype=int), - body_rootid: wp.array(dtype=int), - body_weldid: wp.array(dtype=int), - body_dofnum: wp.array(dtype=int), - body_dofadr: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=wp.vec2), - dof_bodyid: wp.array(dtype=int), - dof_parentid: wp.array(dtype=int), - site_bodyid: wp.array(dtype=int), - eq_obj1id: wp.array(dtype=int), - eq_obj2id: wp.array(dtype=int), - eq_objtype: wp.array(dtype=int), - eq_solref: wp.array2d(dtype=wp.vec2), - eq_solimp: wp.array2d(dtype=vec5), - eq_data: wp.array2d(dtype=vec11), - is_sparse: bool, - eq_connect_adr: wp.array(dtype=int), - # Data in: - qvel_in: wp.array2d(dtype=float), - eq_active_in: wp.array2d(dtype=bool), - xpos_in: wp.array2d(dtype=wp.vec3), - xmat_in: wp.array2d(dtype=wp.mat33), - site_xpos_in: wp.array2d(dtype=wp.vec3), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - njmax_in: int, - # Data out: - ne_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + nsite: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + body_parentid: wp.array(dtype=int), + body_rootid: wp.array(dtype=int), + body_weldid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + body_invweight0: wp.array2d(dtype=wp.vec2), + dof_bodyid: wp.array(dtype=int), + dof_parentid: wp.array(dtype=int), + site_bodyid: wp.array(dtype=int), + eq_obj1id: wp.array(dtype=int), + eq_obj2id: wp.array(dtype=int), + eq_objtype: wp.array(dtype=int), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), + eq_data: wp.array2d(dtype=vec11), + is_sparse: bool, + eq_connect_adr: wp.array(dtype=int), + # Data in: + qvel_in: wp.array2d(dtype=float), + eq_active_in: wp.array2d(dtype=bool), + xpos_in: wp.array2d(dtype=wp.vec3), + xmat_in: wp.array2d(dtype=wp.mat33), + site_xpos_in: wp.array2d(dtype=wp.vec3), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + ne_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): """Calculates constraint rows for connect equality constraints.""" worldid, eqconnectid = wp.tid() @@ -182,6 +186,10 @@ def _equality_connect( if efcid >= njmax_in - 3: return + efcid0 = efcid + 0 + efcid1 = efcid + 1 + efcid2 = efcid + 2 + data = eq_data[worldid % eq_data.shape[0], eqid] anchor1 = wp.vec3f(data[0], data[1], data[2]) anchor2 = wp.vec3f(data[3], data[4], data[5]) @@ -207,26 +215,39 @@ def _equality_connect( Jqvel = wp.vec3f(0.0, 0.0, 0.0) if is_sparse: + # TODO(team): pre-compute number of non-zeros body1 = body_weldid[body1] body2 = body_weldid[body2] da1 = int(body_dofadr[body1] + body_dofnum[body1] - 1) da2 = int(body_dofadr[body2] + body_dofnum[body2] - 1) - efcid0 = efcid + 0 - efcid1 = efcid + 1 - efcid2 = efcid + 2 - - rowadr0 = efcid0 * nv - rowadr1 = efcid1 * nv - rowadr2 = efcid2 * nv + # count non-zeros + pda1 = da1 + pda2 = da2 + rownnz = int(0) + while pda1 >= 0 or pda2 >= 0: + da = wp.max(pda1, pda2) + if pda1 == da: + pda1 = dof_parentid[pda1] + if pda2 == da: + pda2 = dof_parentid[pda2] + rownnz += 1 - efc_J_rowadr_out[worldid, efcid0] = rowadr0 - efc_J_rowadr_out[worldid, efcid1] = rowadr1 - efc_J_rowadr_out[worldid, efcid2] = rowadr2 + # get rowadr + rowadr = wp.atomic_add(efc_nnz_out, worldid, 3 * rownnz) + if rowadr + 3 * rownnz > njmax_nnz_in: + return + efc_J_rowadr_out[worldid, efcid0] = rowadr + efc_J_rowadr_out[worldid, efcid1] = rowadr + rownnz + efc_J_rowadr_out[worldid, efcid2] = rowadr + 2 * rownnz - rownnz = int(0) + efc_J_rownnz_out[worldid, efcid0] = rownnz + efc_J_rownnz_out[worldid, efcid1] = rownnz + efc_J_rownnz_out[worldid, efcid2] = rownnz + # compute J and colind + nnz = int(0) while da1 >= 0 or da2 >= 0: da = wp.max(da1, da2) if da1 == da: @@ -235,32 +256,32 @@ def _equality_connect( da2 = dof_parentid[da2] jacp1, _ = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - pos1, - body1, - da, - worldid, + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + pos1, + body1, + da, + worldid, ) jacp2, _ = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - pos2, - body2, - da, - worldid, + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + pos2, + body2, + da, + worldid, ) j1mj2 = jacp1 - jacp2 - sparseid0 = rowadr0 + rownnz - sparseid1 = rowadr1 + rownnz - sparseid2 = rowadr2 + rownnz + sparseid0 = rowadr + nnz + sparseid1 = rowadr + rownnz + nnz + sparseid2 = rowadr + 2 * rownnz + nnz efc_J_colind_out[worldid, 0, sparseid0] = da efc_J_colind_out[worldid, 0, sparseid1] = da @@ -272,49 +293,42 @@ def _equality_connect( Jqvel += j1mj2 * qvel_in[worldid, da] - rownnz += 1 - - efc_J_rownnz_out[worldid, efcid0] = rownnz - efc_J_rownnz_out[worldid, efcid1] = rownnz - efc_J_rownnz_out[worldid, efcid2] = rownnz + nnz += 1 else: # TODO(team): dof tree traversal for dofid in range(nv): jacp1, _ = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - pos1, - body1, - dofid, - worldid, + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + pos1, + body1, + dofid, + worldid, ) jacp2, _ = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - pos2, - body2, - dofid, - worldid, + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + pos2, + body2, + dofid, + worldid, ) j1mj2 = jacp1 - jacp2 - efc_J_out[worldid, efcid + 0, dofid] = j1mj2[0] - efc_J_out[worldid, efcid + 1, dofid] = j1mj2[1] - efc_J_out[worldid, efcid + 2, dofid] = j1mj2[2] + efc_J_out[worldid, efcid0, dofid] = j1mj2[0] + efc_J_out[worldid, efcid1, dofid] = j1mj2[1] + efc_J_out[worldid, efcid2, dofid] = j1mj2[2] Jqvel += j1mj2 * qvel_in[worldid, dofid] body_invweight0_id = worldid % body_invweight0.shape[0] - invweight = ( - body_invweight0[body_invweight0_id, body1][0] - + body_invweight0[body_invweight0_id, body2][0] - ) + invweight = body_invweight0[body_invweight0_id, body1][0] + body_invweight0[body_invweight0_id, body2][0] pos_imp = wp.length(pos) solref = eq_solref[worldid % eq_solref.shape[0], eqid] @@ -325,68 +339,71 @@ def _equality_connect( efcidi = efcid + i _efc_row( - opt_disableflags, - worldid, - timestep, - efcidi, - pos[i], - pos_imp, - invweight, - solref, - solimp, - 0.0, - Jqvel[i], - 0.0, - ConstraintType.EQUALITY, - eqid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + timestep, + efcidi, + pos[i], + pos_imp, + invweight, + solref, + solimp, + 0.0, + Jqvel[i], + 0.0, + ConstraintType.EQUALITY, + eqid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @wp.kernel def _equality_joint( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - qpos0: wp.array2d(dtype=float), - jnt_qposadr: wp.array(dtype=int), - jnt_dofadr: wp.array(dtype=int), - dof_invweight0: wp.array2d(dtype=float), - eq_obj1id: wp.array(dtype=int), - eq_obj2id: wp.array(dtype=int), - eq_solref: wp.array2d(dtype=wp.vec2), - eq_solimp: wp.array2d(dtype=vec5), - eq_data: wp.array2d(dtype=vec11), - is_sparse: bool, - eq_jnt_adr: wp.array(dtype=int), - # Data in: - qpos_in: wp.array2d(dtype=float), - qvel_in: wp.array2d(dtype=float), - eq_active_in: wp.array2d(dtype=bool), - njmax_in: int, - # Data out: - ne_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + qpos0: wp.array2d(dtype=float), + jnt_qposadr: wp.array(dtype=int), + jnt_dofadr: wp.array(dtype=int), + dof_invweight0: wp.array2d(dtype=float), + eq_obj1id: wp.array(dtype=int), + eq_obj2id: wp.array(dtype=int), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), + eq_data: wp.array2d(dtype=vec11), + is_sparse: bool, + eq_jnt_adr: wp.array(dtype=int), + # Data in: + qpos_in: wp.array2d(dtype=float), + qvel_in: wp.array2d(dtype=float), + eq_active_in: wp.array2d(dtype=bool), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + ne_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid, eqjntid = wp.tid() eqid = eq_jnt_adr[eqjntid] @@ -414,7 +431,9 @@ def _equality_joint( else: rownnz = 1 efc_J_rownnz_out[worldid, efcid] = rownnz - rowadr = efcid * nv + rowadr = wp.atomic_add(efc_nnz_out, worldid, rownnz) + if rowadr + rownnz > njmax_nnz_in: + return efc_J_rowadr_out[worldid, efcid] = rowadr efc_J_colind_out[worldid, 0, rowadr] = dofadr1 efc_J_out[worldid, 0, rowadr] = 1.0 @@ -430,19 +449,12 @@ def _equality_joint( dif = qpos_in[worldid, qposadr2] - qpos0[qpos0_id, qposadr2] # Horner's method for polynomials - rhs = data[0] + dif * ( - data[1] + dif * (data[2] + dif * (data[3] + dif * data[4])) - ) - deriv_2 = data[1] + dif * ( - 2.0 * data[2] + dif * (3.0 * data[3] + dif * 4.0 * data[4]) - ) + rhs = data[0] + dif * (data[1] + dif * (data[2] + dif * (data[3] + dif * data[4]))) + deriv_2 = data[1] + dif * (2.0 * data[2] + dif * (3.0 * data[3] + dif * 4.0 * data[4])) pos = qpos_in[worldid, qposadr1] - qpos0[qpos0_id, qposadr1] - rhs Jqvel = qvel_in[worldid, dofadr1] - qvel_in[worldid, dofadr2] * deriv_2 - invweight = ( - dof_invweight0[dof_invweight0_id, dofadr1] - + dof_invweight0[dof_invweight0_id, dofadr2] - ) + invweight = dof_invweight0[dof_invweight0_id, dofadr1] + dof_invweight0[dof_invweight0_id, dofadr2] if is_sparse: sparseid = rowadr + 1 @@ -458,67 +470,73 @@ def _equality_joint( # Update constraint parameters _efc_row( - opt_disableflags, - worldid, - opt_timestep[worldid % opt_timestep.shape[0]], - efcid, - pos, - pos, - invweight, - eq_solref[worldid % eq_solref.shape[0], eqid], - eq_solimp[worldid % eq_solimp.shape[0], eqid], - 0.0, - Jqvel, - 0.0, - ConstraintType.EQUALITY, - eqid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + opt_timestep[worldid % opt_timestep.shape[0]], + efcid, + pos, + pos, + invweight, + eq_solref[worldid % eq_solref.shape[0], eqid], + eq_solimp[worldid % eq_solimp.shape[0], eqid], + 0.0, + Jqvel, + 0.0, + ConstraintType.EQUALITY, + eqid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @wp.kernel def _equality_tendon( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - eq_obj1id: wp.array(dtype=int), - eq_obj2id: wp.array(dtype=int), - eq_solref: wp.array2d(dtype=wp.vec2), - eq_solimp: wp.array2d(dtype=vec5), - eq_data: wp.array2d(dtype=vec11), - tendon_length0: wp.array2d(dtype=float), - tendon_invweight0: wp.array2d(dtype=float), - is_sparse: bool, - eq_ten_adr: wp.array(dtype=int), - # Data in: - qvel_in: wp.array2d(dtype=float), - eq_active_in: wp.array2d(dtype=bool), - ten_J_in: wp.array3d(dtype=float), - ten_length_in: wp.array2d(dtype=float), - njmax_in: int, - # Data out: - ne_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + eq_obj1id: wp.array(dtype=int), + eq_obj2id: wp.array(dtype=int), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), + eq_data: wp.array2d(dtype=vec11), + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + tendon_length0: wp.array2d(dtype=float), + tendon_invweight0: wp.array2d(dtype=float), + is_sparse: bool, + eq_ten_adr: wp.array(dtype=int), + # Data in: + qvel_in: wp.array2d(dtype=float), + eq_active_in: wp.array2d(dtype=bool), + ten_J_in: wp.array2d(dtype=float), + ten_length_in: wp.array2d(dtype=float), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + ne_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid, eqtenid = wp.tid() eqid = eq_ten_adr[eqtenid] @@ -540,89 +558,118 @@ def _equality_tendon( solimp = eq_solimp[worldid % eq_solimp.shape[0], eqid] tendon_length0_id = worldid % tendon_length0.shape[0] tendon_invweight0_id = worldid % tendon_invweight0.shape[0] - pos1 = ( - ten_length_in[worldid, obj1id] - tendon_length0[tendon_length0_id, obj1id] - ) - jac1 = ten_J_in[worldid, obj1id] + pos1 = ten_length_in[worldid, obj1id] - tendon_length0[tendon_length0_id, obj1id] if obj2id > -1: - invweight = ( - tendon_invweight0[tendon_invweight0_id, obj1id] - + tendon_invweight0[tendon_invweight0_id, obj2id] - ) + invweight = tendon_invweight0[tendon_invweight0_id, obj1id] + tendon_invweight0[tendon_invweight0_id, obj2id] - pos2 = ( - ten_length_in[worldid, obj2id] - - tendon_length0[tendon_length0_id, obj2id] - ) - jac2 = ten_J_in[worldid, obj2id] + pos2 = ten_length_in[worldid, obj2id] - tendon_length0[tendon_length0_id, obj2id] dif = pos2 dif2 = dif * dif dif3 = dif2 * dif dif4 = dif3 * dif - pos = pos1 - ( - data[0] - + data[1] * dif - + data[2] * dif2 - + data[3] * dif3 - + data[4] * dif4 - ) - deriv = ( - data[1] - + 2.0 * data[2] * dif - + 3.0 * data[3] * dif2 - + 4.0 * data[4] * dif3 - ) + pos = pos1 - (data[0] + data[1] * dif + data[2] * dif2 + data[3] * dif3 + data[4] * dif4) + deriv = data[1] + 2.0 * data[2] * dif + 3.0 * data[3] * dif2 + 4.0 * data[4] * dif3 else: invweight = tendon_invweight0[tendon_invweight0_id, obj1id] pos = pos1 - data[0] deriv = 0.0 - Jqvel = float(0.0) + rownnz1 = ten_J_rownnz[obj1id] + rowadr1 = ten_J_rowadr[obj1id] + rownnz2 = 0 + rowadr2 = 0 + + if deriv != 0.0: + rownnz2 = ten_J_rownnz[obj2id] + rowadr2 = ten_J_rowadr[obj2id] - # TODO(team): sparse tendon jacobian if is_sparse: - rowadr = efcid * nv - efc_J_rownnz_out[worldid, efcid] = nv + # TODO(team): pre-compute rownnz + # count unique dofs + p1, p2 = int(0), int(0) + rownnz = int(0) + while p1 < rownnz1 or p2 < rownnz2: + col1 = nv + col2 = nv + if p1 < rownnz1: + col1 = ten_J_colind[rowadr1 + p1] + if p2 < rownnz2: + col2 = ten_J_colind[rowadr2 + p2] + if col1 <= col2: + p1 += 1 + if col2 <= col1: + p2 += 1 + rownnz += 1 + + rowadr = wp.atomic_add(efc_nnz_out, worldid, rownnz) + if rowadr + rownnz > njmax_nnz_in: + return efc_J_rowadr_out[worldid, efcid] = rowadr + ptr1 = int(0) + ptr2 = int(0) + + Jqvel = float(0.0) + + nnz = int(0) for i in range(nv): + J1 = float(0.0) + if ptr1 < rownnz1: + sparseid1 = rowadr1 + ptr1 + if ten_J_colind[sparseid1] == i: + J1 = ten_J_in[worldid, sparseid1] + ptr1 += 1 + + J = J1 if deriv != 0.0: - J = jac1[i] + jac2[i] * -deriv - else: - J = jac1[i] + J2 = float(0.0) + if ptr2 < rownnz2: + sparseid2 = rowadr2 + ptr2 + if ten_J_colind[sparseid2] == i: + J2 = ten_J_in[worldid, sparseid2] + ptr2 += 1 + J += J2 * -deriv + if is_sparse: - efc_J_colind_out[worldid, 0, rowadr + i] = i - efc_J_out[worldid, 0, rowadr + i] = J + if J != 0.0: + sparseid = rowadr + nnz + efc_J_colind_out[worldid, 0, sparseid] = i + efc_J_out[worldid, 0, sparseid] = J + nnz += 1 else: efc_J_out[worldid, efcid, i] = J + Jqvel += J * qvel_in[worldid, i] + if is_sparse: + efc_J_rownnz_out[worldid, efcid] = nnz + _efc_row( - opt_disableflags, - worldid, - opt_timestep[worldid % opt_timestep.shape[0]], - efcid, - pos, - pos, - invweight, - solref, - solimp, - 0.0, - Jqvel, - 0.0, - ConstraintType.EQUALITY, - eqid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + opt_timestep[worldid % opt_timestep.shape[0]], + efcid, + pos, + pos, + invweight, + solref, + solimp, + 0.0, + Jqvel, + 0.0, + ConstraintType.EQUALITY, + eqid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @@ -630,41 +677,50 @@ def _equality_tendon( def _equality_flex(is_sparse: bool): @wp.kernel(module="unique", enable_backward=False) def kernel( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - flexedge_length0: wp.array(dtype=float), - flexedge_invweight0: wp.array(dtype=float), - flexedge_J_rownnz: wp.array(dtype=int), - flexedge_J_rowadr: wp.array(dtype=int), - flexedge_J_colind: wp.array(dtype=int), - eq_solref: wp.array2d(dtype=wp.vec2), - eq_solimp: wp.array2d(dtype=vec5), - eq_flex_adr: wp.array(dtype=int), - # Data in: - qvel_in: wp.array2d(dtype=float), - flexedge_J_in: wp.array2d(dtype=float), - flexedge_length_in: wp.array2d(dtype=float), - njmax_in: int, - # Data out: - ne_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + flex_edgeadr: wp.array(dtype=int), + flex_edgenum: wp.array(dtype=int), + flexedge_length0: wp.array(dtype=float), + flexedge_invweight0: wp.array(dtype=float), + flexedge_J_rownnz: wp.array(dtype=int), + flexedge_J_rowadr: wp.array(dtype=int), + flexedge_J_colind: wp.array(dtype=int), + eq_obj1id: wp.array(dtype=int), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), + eq_flex_adr: wp.array(dtype=int), + # Data in: + qvel_in: wp.array2d(dtype=float), + flexedge_J_in: wp.array2d(dtype=float), + flexedge_length_in: wp.array2d(dtype=float), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + ne_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid, eqflexid, edgeid = wp.tid() eqid = eq_flex_adr[eqflexid] + flexid = eq_obj1id[eqid] + if edgeid < flex_edgeadr[flexid] or edgeid >= flex_edgeadr[flexid] + flex_edgenum[flexid]: + return wp.atomic_add(ne_out, worldid, 1) efcid = wp.atomic_add(nefc_out, worldid, 1) @@ -683,7 +739,9 @@ def kernel( if wp.static(is_sparse): efc_J_rownnz_out[worldid, efcid] = rownnz - efc_rowadr = efcid * nv + efc_rowadr = wp.atomic_add(efc_nnz_out, worldid, rownnz) + if efc_rowadr + rownnz > njmax_nnz_in: + return efc_J_rowadr_out[worldid, efcid] = efc_rowadr for i in range(rownnz): flex_sparseid = flex_rowadr + i @@ -704,83 +762,86 @@ def kernel( Jqvel += J * qvel_in[worldid, colind] _efc_row( - opt_disableflags, - worldid, - opt_timestep[worldid % opt_timestep.shape[0]], - efcid, - pos, - pos, - flexedge_invweight0[edgeid], - solref, - solimp, - 0.0, - Jqvel, - 0.0, - ConstraintType.EQUALITY, - eqid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, - ) - - return kernel + opt_disableflags, + worldid, + opt_timestep[worldid % opt_timestep.shape[0]], + efcid, + pos, + pos, + flexedge_invweight0[edgeid], + solref, + solimp, + 0.0, + Jqvel, + 0.0, + ConstraintType.EQUALITY, + eqid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, + ) + + return kernel @wp.kernel def _equality_weld( - # Model: - nv: int, - nsite: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - body_parentid: wp.array(dtype=int), - body_rootid: wp.array(dtype=int), - body_weldid: wp.array(dtype=int), - body_dofnum: wp.array(dtype=int), - body_dofadr: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=wp.vec2), - dof_bodyid: wp.array(dtype=int), - dof_parentid: wp.array(dtype=int), - site_bodyid: wp.array(dtype=int), - site_quat: wp.array2d(dtype=wp.quat), - eq_obj1id: wp.array(dtype=int), - eq_obj2id: wp.array(dtype=int), - eq_objtype: wp.array(dtype=int), - eq_solref: wp.array2d(dtype=wp.vec2), - eq_solimp: wp.array2d(dtype=vec5), - eq_data: wp.array2d(dtype=vec11), - is_sparse: bool, - eq_wld_adr: wp.array(dtype=int), - # Data in: - qvel_in: wp.array2d(dtype=float), - eq_active_in: wp.array2d(dtype=bool), - xpos_in: wp.array2d(dtype=wp.vec3), - xquat_in: wp.array2d(dtype=wp.quat), - xmat_in: wp.array2d(dtype=wp.mat33), - site_xpos_in: wp.array2d(dtype=wp.vec3), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - njmax_in: int, - # Data out: - ne_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + nsite: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + body_parentid: wp.array(dtype=int), + body_rootid: wp.array(dtype=int), + body_weldid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + body_invweight0: wp.array2d(dtype=wp.vec2), + dof_bodyid: wp.array(dtype=int), + dof_parentid: wp.array(dtype=int), + site_bodyid: wp.array(dtype=int), + site_quat: wp.array2d(dtype=wp.quat), + eq_obj1id: wp.array(dtype=int), + eq_obj2id: wp.array(dtype=int), + eq_objtype: wp.array(dtype=int), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), + eq_data: wp.array2d(dtype=vec11), + is_sparse: bool, + eq_wld_adr: wp.array(dtype=int), + # Data in: + qvel_in: wp.array2d(dtype=float), + eq_active_in: wp.array2d(dtype=bool), + xpos_in: wp.array2d(dtype=wp.vec3), + xquat_in: wp.array2d(dtype=wp.quat), + xmat_in: wp.array2d(dtype=wp.mat33), + site_xpos_in: wp.array2d(dtype=wp.vec3), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + ne_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid, eqweldid = wp.tid() eqid = eq_wld_adr[eqweldid] @@ -794,6 +855,13 @@ def _equality_weld( if efcid >= njmax_in - 6: return + efcid0 = efcid + 0 + efcid1 = efcid + 1 + efcid2 = efcid + 2 + efcid3 = efcid + 3 + efcid4 = efcid + 4 + efcid5 = efcid + 5 + is_site = eq_objtype[eqid] == types.ObjType.SITE and nsite > 0 obj1id = eq_obj1id[eqid] @@ -812,12 +880,8 @@ def _equality_weld( pos2 = site_xpos_in[worldid, obj2id] site_quat_id = worldid % site_quat.shape[0] - quat = math.mul_quat( - xquat_in[worldid, body1], site_quat[site_quat_id, obj1id] - ) - quat1 = math.quat_inv( - math.mul_quat(xquat_in[worldid, body2], site_quat[site_quat_id, obj2id]) - ) + quat = math.mul_quat(xquat_in[worldid, body1], site_quat[site_quat_id, obj1id]) + quat1 = math.quat_inv(math.mul_quat(xquat_in[worldid, body2], site_quat[site_quat_id, obj2id])) else: body1 = obj1id @@ -833,35 +897,45 @@ def _equality_weld( Jqvelr = wp.vec3f(0.0, 0.0, 0.0) if is_sparse: + # TODO(team): pre-compute number of non-zeros body1 = body_weldid[body1] body2 = body_weldid[body2] da1 = int(body_dofadr[body1] + body_dofnum[body1] - 1) da2 = int(body_dofadr[body2] + body_dofnum[body2] - 1) - efcid0 = efcid + 0 - efcid1 = efcid + 1 - efcid2 = efcid + 2 - efcid3 = efcid + 3 - efcid4 = efcid + 4 - efcid5 = efcid + 5 - - rowadr0 = efcid0 * nv - rowadr1 = efcid1 * nv - rowadr2 = efcid2 * nv - rowadr3 = efcid3 * nv - rowadr4 = efcid4 * nv - rowadr5 = efcid5 * nv - - efc_J_rowadr_out[worldid, efcid0] = rowadr0 - efc_J_rowadr_out[worldid, efcid1] = rowadr1 - efc_J_rowadr_out[worldid, efcid2] = rowadr2 - efc_J_rowadr_out[worldid, efcid3] = rowadr3 - efc_J_rowadr_out[worldid, efcid4] = rowadr4 - efc_J_rowadr_out[worldid, efcid5] = rowadr5 - + # count non-zeros + pda1 = da1 + pda2 = da2 rownnz = int(0) + while pda1 >= 0 or pda2 >= 0: + da = wp.max(pda1, pda2) + if pda1 == da: + pda1 = dof_parentid[da] + if pda2 == da: + pda2 = dof_parentid[da] + rownnz += 1 + + # get rowadr + rowadr = wp.atomic_add(efc_nnz_out, worldid, 6 * rownnz) + if rowadr + 6 * rownnz > njmax_nnz_in: + return + efc_J_rowadr_out[worldid, efcid0] = rowadr + efc_J_rowadr_out[worldid, efcid1] = rowadr + rownnz + efc_J_rowadr_out[worldid, efcid2] = rowadr + 2 * rownnz + efc_J_rowadr_out[worldid, efcid3] = rowadr + 3 * rownnz + efc_J_rowadr_out[worldid, efcid4] = rowadr + 4 * rownnz + efc_J_rowadr_out[worldid, efcid5] = rowadr + 5 * rownnz + + efc_J_rownnz_out[worldid, efcid0] = rownnz + efc_J_rownnz_out[worldid, efcid1] = rownnz + efc_J_rownnz_out[worldid, efcid2] = rownnz + efc_J_rownnz_out[worldid, efcid3] = rownnz + efc_J_rownnz_out[worldid, efcid4] = rownnz + efc_J_rownnz_out[worldid, efcid5] = rownnz + # compute J and colind + nnz = int(0) while da1 >= 0 or da2 >= 0: da = wp.max(da1, da2) if da1 == da: @@ -870,26 +944,26 @@ def _equality_weld( da2 = dof_parentid[da] jacp1, jacr1 = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - pos1, - body1, - da, - worldid, + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + pos1, + body1, + da, + worldid, ) jacp2, jacr2 = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - pos2, - body2, - da, - worldid, + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + pos2, + body2, + da, + worldid, ) jacdifp = jacp1 - jacp2 @@ -898,12 +972,12 @@ def _equality_weld( jacdifrq = math.mul_quat(math.quat_mul_axis(quat1, jacdifr), quat) jacdifr = 0.5 * wp.vec3(jacdifrq[1], jacdifrq[2], jacdifrq[3]) - sparseid0 = rowadr0 + rownnz - sparseid1 = rowadr1 + rownnz - sparseid2 = rowadr2 + rownnz - sparseid3 = rowadr3 + rownnz - sparseid4 = rowadr4 + rownnz - sparseid5 = rowadr5 + rownnz + sparseid0 = rowadr + nnz + sparseid1 = rowadr + rownnz + nnz + sparseid2 = rowadr + 2 * rownnz + nnz + sparseid3 = rowadr + 3 * rownnz + nnz + sparseid4 = rowadr + 4 * rownnz + nnz + sparseid5 = rowadr + 5 * rownnz + nnz efc_J_colind_out[worldid, 0, sparseid0] = da efc_J_colind_out[worldid, 0, sparseid1] = da @@ -922,50 +996,45 @@ def _equality_weld( Jqvelp += jacdifp * qvel_in[worldid, da] Jqvelr += jacdifr * qvel_in[worldid, da] - rownnz += 1 - - efc_J_rownnz_out[worldid, efcid0] = rownnz - efc_J_rownnz_out[worldid, efcid1] = rownnz - efc_J_rownnz_out[worldid, efcid2] = rownnz - efc_J_rownnz_out[worldid, efcid3] = rownnz - efc_J_rownnz_out[worldid, efcid4] = rownnz - efc_J_rownnz_out[worldid, efcid5] = rownnz + nnz += 1 else: for dofid in range(nv): jacp1, jacr1 = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - pos1, - body1, - dofid, - worldid, + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + pos1, + body1, + dofid, + worldid, ) jacp2, jacr2 = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - pos2, - body2, - dofid, - worldid, + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + pos2, + body2, + dofid, + worldid, ) jacdifp = jacp1 - jacp2 - for i in range(3): - efc_J_out[worldid, efcid + i, dofid] = jacdifp[i] + efc_J_out[worldid, efcid0, dofid] = jacdifp[0] + efc_J_out[worldid, efcid1, dofid] = jacdifp[1] + efc_J_out[worldid, efcid2, dofid] = jacdifp[2] jacdifr = (jacr1 - jacr2) * torquescale jacdifrq = math.mul_quat(math.quat_mul_axis(quat1, jacdifr), quat) jacdifr = 0.5 * wp.vec3(jacdifrq[1], jacdifrq[2], jacdifrq[3]) - for i in range(3): - efc_J_out[worldid, efcid + 3 + i, dofid] = jacdifr[i] + efc_J_out[worldid, efcid3, dofid] = jacdifr[0] + efc_J_out[worldid, efcid4, dofid] = jacdifr[1] + efc_J_out[worldid, efcid5, dofid] = jacdifr[2] Jqvelp += jacdifp * qvel_in[worldid, dofid] Jqvelr += jacdifr * qvel_in[worldid, dofid] @@ -977,10 +1046,7 @@ def _equality_weld( crot = wp.vec3(crotq[1], crotq[2], crotq[3]) * torquescale body_invweight0_id = worldid % body_invweight0.shape[0] - invweight_t = ( - body_invweight0[body_invweight0_id, body1][0] - + body_invweight0[body_invweight0_id, body2][0] - ) + invweight_t = body_invweight0[body_invweight0_id, body1][0] + body_invweight0[body_invweight0_id, body2][0] pos_imp = wp.sqrt(wp.length_sq(cpos) + wp.length_sq(crot)) @@ -991,91 +1057,91 @@ def _equality_weld( for i in range(3): _efc_row( - opt_disableflags, - worldid, - timestep, - efcid + i, - cpos[i], - pos_imp, - invweight_t, - solref, - solimp, - 0.0, - Jqvelp[i], - 0.0, - ConstraintType.EQUALITY, - eqid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + timestep, + efcid + i, + cpos[i], + pos_imp, + invweight_t, + solref, + solimp, + 0.0, + Jqvelp[i], + 0.0, + ConstraintType.EQUALITY, + eqid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) - invweight_r = ( - body_invweight0[body_invweight0_id, body1][1] - + body_invweight0[body_invweight0_id, body2][1] - ) + invweight_r = body_invweight0[body_invweight0_id, body1][1] + body_invweight0[body_invweight0_id, body2][1] for i in range(3): _efc_row( - opt_disableflags, - worldid, - timestep, - efcid + 3 + i, - crot[i], - pos_imp, - invweight_r, - solref, - solimp, - 0.0, - Jqvelr[i], - 0.0, - ConstraintType.EQUALITY, - eqid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + timestep, + efcid + 3 + i, + crot[i], + pos_imp, + invweight_r, + solref, + solimp, + 0.0, + Jqvelr[i], + 0.0, + ConstraintType.EQUALITY, + eqid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @wp.kernel def _friction_dof( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - dof_solref: wp.array2d(dtype=wp.vec2), - dof_solimp: wp.array2d(dtype=vec5), - dof_frictionloss: wp.array2d(dtype=float), - dof_invweight0: wp.array2d(dtype=float), - is_sparse: bool, - # Data in: - qvel_in: wp.array2d(dtype=float), - njmax_in: int, - # Data out: - nf_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + dof_solref: wp.array2d(dtype=wp.vec2), + dof_solimp: wp.array2d(dtype=vec5), + dof_frictionloss: wp.array2d(dtype=float), + dof_invweight0: wp.array2d(dtype=float), + is_sparse: bool, + # Data in: + qvel_in: wp.array2d(dtype=float), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + nf_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid, dofid = wp.tid() @@ -1092,7 +1158,9 @@ def _friction_dof( if is_sparse: efc_J_rownnz_out[worldid, efcid] = 1 - rowadr = efcid * nv + rowadr = wp.atomic_add(efc_nnz_out, worldid, 1) + if rowadr + 1 > njmax_nnz_in: + return efc_J_rowadr_out[worldid, efcid] = rowadr efc_J_colind_out[worldid, 0, rowadr] = dofid efc_J_out[worldid, 0, rowadr] = 1.0 @@ -1107,61 +1175,67 @@ def _friction_dof( dof_solref_id = worldid % dof_solref.shape[0] dof_solimp_id = worldid % dof_solimp.shape[0] _efc_row( - opt_disableflags, - worldid, - opt_timestep[worldid % opt_timestep.shape[0]], - efcid, - 0.0, - 0.0, - dof_invweight0[dof_invweight0_id, dofid], - dof_solref[dof_solref_id, dofid], - dof_solimp[dof_solimp_id, dofid], - 0.0, - Jqvel, - dof_frictionloss[dof_frictionloss_id, dofid], - ConstraintType.FRICTION_DOF, - dofid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + opt_timestep[worldid % opt_timestep.shape[0]], + efcid, + 0.0, + 0.0, + dof_invweight0[dof_invweight0_id, dofid], + dof_solref[dof_solref_id, dofid], + dof_solimp[dof_solimp_id, dofid], + 0.0, + Jqvel, + dof_frictionloss[dof_frictionloss_id, dofid], + ConstraintType.FRICTION_DOF, + dofid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @wp.kernel def _friction_tendon( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - tendon_solref_fri: wp.array2d(dtype=wp.vec2), - tendon_solimp_fri: wp.array2d(dtype=vec5), - tendon_frictionloss: wp.array2d(dtype=float), - tendon_invweight0: wp.array2d(dtype=float), - is_sparse: bool, - # Data in: - qvel_in: wp.array2d(dtype=float), - ten_J_in: wp.array3d(dtype=float), - njmax_in: int, - # Data out: - nf_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + tendon_solref_fri: wp.array2d(dtype=wp.vec2), + tendon_solimp_fri: wp.array2d(dtype=vec5), + tendon_frictionloss: wp.array2d(dtype=float), + tendon_invweight0: wp.array2d(dtype=float), + is_sparse: bool, + # Data in: + qvel_in: wp.array2d(dtype=float), + ten_J_in: wp.array2d(dtype=float), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + nf_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid, tenid = wp.tid() @@ -1179,86 +1253,103 @@ def _friction_tendon( Jqvel = float(0.0) - # TODO(team): sparse tendon jacobian + rownnz_tenJ = ten_J_rownnz[tenid] + rowadr_tenJ = ten_J_rowadr[tenid] if is_sparse: - rowadr = efcid * nv - efc_J_rownnz_out[worldid, efcid] = nv - efc_J_rowadr_out[worldid, efcid] = rowadr - - for i in range(nv): - # TODO(team): sparse ten_J - J = ten_J_in[worldid, tenid, i] - if is_sparse: - efc_J_colind_out[worldid, 0, rowadr + i] = i - efc_J_out[worldid, 0, rowadr + i] = J - else: - efc_J_out[worldid, efcid, i] = J - - Jqvel += J * qvel_in[worldid, i] + efc_J_rownnz_out[worldid, efcid] = rownnz_tenJ + rowadr_efc = wp.atomic_add(efc_nnz_out, worldid, rownnz_tenJ) + if rowadr_efc + rownnz_tenJ > njmax_nnz_in: + return + efc_J_rowadr_out[worldid, efcid] = rowadr_efc + + for i in range(rownnz_tenJ): + sparseid_ten = rowadr_tenJ + i + sparseid_efc = rowadr_efc + i + colind = ten_J_colind[sparseid_ten] + J = ten_J_in[worldid, sparseid_ten] + efc_J_colind_out[worldid, 0, sparseid_efc] = colind + efc_J_out[worldid, 0, sparseid_efc] = J + Jqvel += J * qvel_in[worldid, colind] + else: + nnz = int(0) + colind = ten_J_colind[rowadr_tenJ] + for i in range(nv): + if nnz < rownnz_tenJ and i == colind: + J = ten_J_in[worldid, rowadr_tenJ + nnz] + efc_J_out[worldid, efcid, i] = J + Jqvel += J * qvel_in[worldid, i] + nnz += 1 + if nnz < rownnz_tenJ: + colind = ten_J_colind[rowadr_tenJ + nnz] + else: + efc_J_out[worldid, efcid, i] = 0.0 tendon_invweight0_id = worldid % tendon_invweight0.shape[0] tendon_solref_fri_id = worldid % tendon_solref_fri.shape[0] tendon_solimp_fri_id = worldid % tendon_solimp_fri.shape[0] _efc_row( - opt_disableflags, - worldid, - opt_timestep[worldid % opt_timestep.shape[0]], - efcid, - 0.0, - 0.0, - tendon_invweight0[tendon_invweight0_id, tenid], - tendon_solref_fri[tendon_solref_fri_id, tenid], - tendon_solimp_fri[tendon_solimp_fri_id, tenid], - 0.0, - Jqvel, - frictionloss, - ConstraintType.FRICTION_TENDON, - tenid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + opt_timestep[worldid % opt_timestep.shape[0]], + efcid, + 0.0, + 0.0, + tendon_invweight0[tendon_invweight0_id, tenid], + tendon_solref_fri[tendon_solref_fri_id, tenid], + tendon_solimp_fri[tendon_solimp_fri_id, tenid], + 0.0, + Jqvel, + frictionloss, + ConstraintType.FRICTION_TENDON, + tenid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @wp.kernel def _limit_slide_hinge( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - jnt_qposadr: wp.array(dtype=int), - jnt_dofadr: wp.array(dtype=int), - jnt_solref: wp.array2d(dtype=wp.vec2), - jnt_solimp: wp.array2d(dtype=vec5), - jnt_range: wp.array2d(dtype=wp.vec2), - jnt_margin: wp.array2d(dtype=float), - dof_invweight0: wp.array2d(dtype=float), - is_sparse: bool, - jnt_limited_slide_hinge_adr: wp.array(dtype=int), - # Data in: - qpos_in: wp.array2d(dtype=float), - qvel_in: wp.array2d(dtype=float), - njmax_in: int, - # Data out: - nl_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + jnt_qposadr: wp.array(dtype=int), + jnt_dofadr: wp.array(dtype=int), + jnt_solref: wp.array2d(dtype=wp.vec2), + jnt_solimp: wp.array2d(dtype=vec5), + jnt_range: wp.array2d(dtype=wp.vec2), + jnt_margin: wp.array2d(dtype=float), + dof_invweight0: wp.array2d(dtype=float), + is_sparse: bool, + jnt_limited_slide_hinge_adr: wp.array(dtype=int), + # Data in: + qpos_in: wp.array2d(dtype=float), + qvel_in: wp.array2d(dtype=float), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + nl_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid, jntlimitedid = wp.tid() jntid = jnt_limited_slide_hinge_adr[jntlimitedid] @@ -1285,7 +1376,9 @@ def _limit_slide_hinge( if is_sparse: efc_J_rownnz_out[worldid, efcid] = 1 - rowadr = efcid * nv + rowadr = wp.atomic_add(efc_nnz_out, worldid, 1) + if rowadr + 1 > njmax_nnz_in: + return efc_J_rowadr_out[worldid, efcid] = rowadr efc_J_colind_out[worldid, 0, rowadr] = dofadr efc_J_out[worldid, 0, rowadr] = J @@ -1300,65 +1393,68 @@ def _limit_slide_hinge( jnt_solref_id = worldid % jnt_solref.shape[0] jnt_solimp_id = worldid % jnt_solimp.shape[0] _efc_row( - opt_disableflags, - worldid, - opt_timestep[worldid % opt_timestep.shape[0]], - efcid, - pos, - pos, - dof_invweight0[dof_invweight0_id, dofadr], - jnt_solref[jnt_solref_id, jntid], - jnt_solimp[jnt_solimp_id, jntid], - jntmargin, - Jqvel, - 0.0, - ConstraintType.LIMIT_JOINT, - jntid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, - ) - - -@wp.kernel -def _limit_ball( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - jnt_qposadr: wp.array(dtype=int), - jnt_dofadr: wp.array(dtype=int), - jnt_solref: wp.array2d(dtype=wp.vec2), - jnt_solimp: wp.array2d(dtype=vec5), - jnt_range: wp.array2d(dtype=wp.vec2), - jnt_margin: wp.array2d(dtype=float), - dof_invweight0: wp.array2d(dtype=float), - is_sparse: bool, - jnt_limited_ball_adr: wp.array(dtype=int), - # Data in: - qpos_in: wp.array2d(dtype=float), - qvel_in: wp.array2d(dtype=float), - njmax_in: int, - # Data out: - nl_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + opt_disableflags, + worldid, + opt_timestep[worldid % opt_timestep.shape[0]], + efcid, + pos, + pos, + dof_invweight0[dof_invweight0_id, dofadr], + jnt_solref[jnt_solref_id, jntid], + jnt_solimp[jnt_solimp_id, jntid], + jntmargin, + Jqvel, + 0.0, + ConstraintType.LIMIT_JOINT, + jntid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, + ) + + +@wp.kernel +def _limit_ball( + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + jnt_qposadr: wp.array(dtype=int), + jnt_dofadr: wp.array(dtype=int), + jnt_solref: wp.array2d(dtype=wp.vec2), + jnt_solimp: wp.array2d(dtype=vec5), + jnt_range: wp.array2d(dtype=wp.vec2), + jnt_margin: wp.array2d(dtype=float), + dof_invweight0: wp.array2d(dtype=float), + is_sparse: bool, + jnt_limited_ball_adr: wp.array(dtype=int), + # Data in: + qpos_in: wp.array2d(dtype=float), + qvel_in: wp.array2d(dtype=float), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + nl_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid, jntlimitedid = wp.tid() jntid = jnt_limited_ball_adr[jntlimitedid] @@ -1391,7 +1487,9 @@ def _limit_ball( if is_sparse: efc_J_rownnz_out[worldid, efcid] = 3 - rowadr = efcid * nv + rowadr = wp.atomic_add(efc_nnz_out, worldid, 3) + if rowadr + 3 > njmax_nnz_in: + return efc_J_rowadr_out[worldid, efcid] = rowadr sparseid0 = rowadr + 0 @@ -1420,69 +1518,70 @@ def _limit_ball( jnt_solref_id = worldid % jnt_solref.shape[0] jnt_solimp_id = worldid % jnt_solimp.shape[0] _efc_row( - opt_disableflags, - worldid, - opt_timestep[worldid % opt_timestep.shape[0]], - efcid, - pos, - pos, - dof_invweight0[dof_invweight0_id, dofadr], - jnt_solref[jnt_solref_id, jntid], - jnt_solimp[jnt_solimp_id, jntid], - jntmargin, - Jqvel, - 0.0, - ConstraintType.LIMIT_JOINT, - jntid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + opt_timestep[worldid % opt_timestep.shape[0]], + efcid, + pos, + pos, + dof_invweight0[dof_invweight0_id, dofadr], + jnt_solref[jnt_solref_id, jntid], + jnt_solimp[jnt_solimp_id, jntid], + jntmargin, + Jqvel, + 0.0, + ConstraintType.LIMIT_JOINT, + jntid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @wp.kernel def _limit_tendon( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - jnt_dofadr: wp.array(dtype=int), - tendon_adr: wp.array(dtype=int), - tendon_num: wp.array(dtype=int), - tendon_solref_lim: wp.array2d(dtype=wp.vec2), - tendon_solimp_lim: wp.array2d(dtype=vec5), - tendon_range: wp.array2d(dtype=wp.vec2), - tendon_margin: wp.array2d(dtype=float), - tendon_invweight0: wp.array2d(dtype=float), - wrap_type: wp.array(dtype=int), - wrap_objid: wp.array(dtype=int), - is_sparse: bool, - tendon_limited_adr: wp.array(dtype=int), - # Data in: - qvel_in: wp.array2d(dtype=float), - ten_J_in: wp.array3d(dtype=float), - ten_length_in: wp.array2d(dtype=float), - njmax_in: int, - # Data out: - nl_out: wp.array(dtype=int), - nefc_out: wp.array(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + tendon_solref_lim: wp.array2d(dtype=wp.vec2), + tendon_solimp_lim: wp.array2d(dtype=vec5), + tendon_range: wp.array2d(dtype=wp.vec2), + tendon_margin: wp.array2d(dtype=float), + tendon_invweight0: wp.array2d(dtype=float), + is_sparse: bool, + tendon_limited_adr: wp.array(dtype=int), + # Data in: + qvel_in: wp.array2d(dtype=float), + ten_J_in: wp.array2d(dtype=float), + ten_length_in: wp.array2d(dtype=float), + njmax_in: int, + njmax_nnz_in: int, + # Data out: + nl_out: wp.array(dtype=int), + nefc_out: wp.array(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): worldid, tenlimitedid = wp.tid() tenid = tendon_limited_adr[tenlimitedid] @@ -1506,122 +1605,123 @@ def _limit_tendon( Jqvel = float(0.0) scl = float(dist_min < dist_max) * 2.0 - 1.0 - # TODO(team): sparse tendon jacobian + rownnz_tenJ = ten_J_rownnz[tenid] + rowadr_tenJ = ten_J_rowadr[tenid] if is_sparse: - rowadr = efcid * nv - efc_J_rownnz_out[worldid, efcid] = nv - efc_J_rowadr_out[worldid, efcid] = rowadr - for i in range(nv): - efc_J_colind_out[worldid, 0, rowadr + i] = i - efc_J_out[worldid, 0, rowadr + i] = 0.0 - - adr = tendon_adr[tenid] - if wrap_type[adr] == types.WrapType.JOINT: - if not is_sparse: - for i in range(nv): - efc_J_out[worldid, efcid, i] = 0.0 - - ten_num = tendon_num[tenid] - for i in range(ten_num): - dofadr = jnt_dofadr[wrap_objid[adr + i]] - J = scl * ten_J_in[worldid, tenid, dofadr] - - if is_sparse: - efc_J_out[worldid, 0, rowadr + dofadr] = J - else: - efc_J_out[worldid, efcid, dofadr] = J - - Jqvel += J * qvel_in[worldid, dofadr] + efc_J_rownnz_out[worldid, efcid] = rownnz_tenJ + rowadr_efc = wp.atomic_add(efc_nnz_out, worldid, rownnz_tenJ) + if rowadr_efc + rownnz_tenJ > njmax_nnz_in: + return + efc_J_rowadr_out[worldid, efcid] = rowadr_efc + + for i in range(rownnz_tenJ): + sparseid_ten = rowadr_tenJ + i + sparseid_efc = rowadr_efc + i + colind = ten_J_colind[sparseid_ten] + J = scl * ten_J_in[worldid, sparseid_ten] + efc_J_colind_out[worldid, 0, sparseid_efc] = colind + efc_J_out[worldid, 0, sparseid_efc] = J + Jqvel += J * qvel_in[worldid, colind] else: + nnz = int(0) + colind = ten_J_colind[rowadr_tenJ] for i in range(nv): - J = scl * ten_J_in[worldid, tenid, i] - - if is_sparse: - efc_J_out[worldid, 0, rowadr + i] = J - else: + if nnz < rownnz_tenJ and i == colind: + J = scl * ten_J_in[worldid, rowadr_tenJ + nnz] efc_J_out[worldid, efcid, i] = J - - Jqvel += J * qvel_in[worldid, i] + Jqvel += J * qvel_in[worldid, i] + nnz += 1 + if nnz < rownnz_tenJ: + colind = ten_J_colind[rowadr_tenJ + nnz] + else: + efc_J_out[worldid, efcid, i] = 0.0 tendon_invweight0_id = worldid % tendon_invweight0.shape[0] tendon_solref_lim_id = worldid % tendon_solref_lim.shape[0] tendon_solimp_lim_id = worldid % tendon_solimp_lim.shape[0] _efc_row( - opt_disableflags, - worldid, - opt_timestep[worldid % opt_timestep.shape[0]], - efcid, - pos, - pos, - tendon_invweight0[tendon_invweight0_id, tenid], - tendon_solref_lim[tendon_solref_lim_id, tenid], - tendon_solimp_lim[tendon_solimp_lim_id, tenid], - tenmargin, - Jqvel, - 0.0, - ConstraintType.LIMIT_TENDON, - tenid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + opt_timestep[worldid % opt_timestep.shape[0]], + efcid, + pos, + pos, + tendon_invweight0[tendon_invweight0_id, tenid], + tendon_solref_lim[tendon_solref_lim_id, tenid], + tendon_solimp_lim[tendon_solimp_lim_id, tenid], + tenmargin, + Jqvel, + 0.0, + ConstraintType.LIMIT_TENDON, + tenid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @wp.kernel def _contact_pyramidal( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - opt_impratio_invsqrt: wp.array(dtype=float), - body_parentid: wp.array(dtype=int), - body_rootid: wp.array(dtype=int), - body_weldid: wp.array(dtype=int), - body_dofnum: wp.array(dtype=int), - body_dofadr: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=wp.vec2), - dof_bodyid: wp.array(dtype=int), - dof_parentid: wp.array(dtype=int), - geom_bodyid: wp.array(dtype=int), - is_sparse: bool, - # Data in: - qvel_in: wp.array2d(dtype=float), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - njmax_in: int, - nacon_in: wp.array(dtype=int), - # In: - dist_in: wp.array(dtype=float), - condim_in: wp.array(dtype=int), - includemargin_in: wp.array(dtype=float), - worldid_in: wp.array(dtype=int), - geom_in: wp.array(dtype=wp.vec2i), - pos_in: wp.array(dtype=wp.vec3), - frame_in: wp.array(dtype=wp.mat33), - friction_in: wp.array(dtype=vec5), - solref_in: wp.array(dtype=wp.vec2), - solimp_in: wp.array(dtype=vec5), - type_in: wp.array(dtype=int), - # Data out: - nefc_out: wp.array(dtype=int), - contact_efc_address_out: wp.array2d(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + opt_impratio_invsqrt: wp.array(dtype=float), + body_parentid: wp.array(dtype=int), + body_rootid: wp.array(dtype=int), + body_weldid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + body_invweight0: wp.array2d(dtype=wp.vec2), + dof_bodyid: wp.array(dtype=int), + dof_parentid: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + flex_vertadr: wp.array(dtype=int), + flex_vertbodyid: wp.array(dtype=int), + is_sparse: bool, + # Data in: + qvel_in: wp.array2d(dtype=float), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + njmax_in: int, + njmax_nnz_in: int, + nacon_in: wp.array(dtype=int), + # In: + dist_in: wp.array(dtype=float), + condim_in: wp.array(dtype=int), + includemargin_in: wp.array(dtype=float), + worldid_in: wp.array(dtype=int), + geom_in: wp.array(dtype=wp.vec2i), + flex_in: wp.array(dtype=wp.vec2i), + vert_in: wp.array(dtype=wp.vec2i), + pos_in: wp.array(dtype=wp.vec3), + frame_in: wp.array(dtype=wp.mat33), + friction_in: wp.array(dtype=vec5), + solref_in: wp.array(dtype=wp.vec2), + solimp_in: wp.array(dtype=vec5), + type_in: wp.array(dtype=int), + # Data out: + nefc_out: wp.array(dtype=int), + contact_efc_address_out: wp.array2d(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): conid, dimid = wp.tid() @@ -1655,8 +1755,20 @@ def _contact_pyramidal( contact_efc_address_out[conid, dimid] = efcid geom = geom_in[conid] - body1 = geom_bodyid[geom[0]] - body2 = geom_bodyid[geom[1]] + + if geom[0] >= 0: + body1 = geom_bodyid[geom[0]] + else: + flex = flex_in[conid] + vert = vert_in[conid] + body1 = flex_vertbodyid[flex_vertadr[flex[0]] + vert[0]] + + if geom[1] >= 0: + body2 = geom_bodyid[geom[1]] + else: + flex = flex_in[conid] + vert = vert_in[conid] + body2 = flex_vertbodyid[flex_vertadr[flex[1]] + vert[1]] con_pos = pos_in[conid] frame = frame_in[conid] @@ -1674,29 +1786,48 @@ def _contact_pyramidal( invweight = invweight + fri0 * fri0 * invweight invweight = invweight * 2.0 * fri0 * fri0 * impratio_invsqrt * impratio_invsqrt - if is_sparse: - rowadr = efcid * nv - efc_J_rowadr_out[worldid, efcid] = rowadr - Jqvel = float(0.0) # skip fixed bodies body1 = body_weldid[body1] body2 = body_weldid[body2] - da1 = body_dofadr[body1] + body_dofnum[body1] - 1 - da2 = body_dofadr[body2] + body_dofnum[body2] - 1 - da = wp.max(da1, da2) + da1 = int(body_dofadr[body1] + body_dofnum[body1] - 1) + da2 = int(body_dofadr[body2] + body_dofnum[body2] - 1) if is_sparse: + pda1 = da1 + pda2 = da2 rownnz = int(0) + while pda1 >= 0 or pda2 >= 0: + da = wp.max(pda1, pda2) + # skip common dofs + if pda1 == da and pda2 == da: + break + if pda1 == da: + pda1 = dof_parentid[pda1] + if pda2 == da: + pda2 = dof_parentid[pda2] + rownnz += 1 + + # get rowadr + rowadr = wp.atomic_add(efc_nnz_out, worldid, rownnz) + if rowadr + rownnz > njmax_nnz_in: + return + efc_J_rowadr_out[worldid, efcid] = rowadr + efc_J_rownnz_out[worldid, efcid] = rownnz + + da = wp.max(da1, da2) + + if is_sparse: + nnz = int(0) dofid = int(da) else: dofid = int(nv - 1) while True: if is_sparse: - if da1 < 0 and da2 < 0: + if nnz >= rownnz: break else: if dofid < 0: @@ -1749,13 +1880,15 @@ def _contact_pyramidal( J -= Ji * frii if is_sparse: - sparseid = rowadr + rownnz + sparseid = rowadr + nnz efc_J_colind_out[worldid, 0, sparseid] = dofid efc_J_out[worldid, 0, sparseid] = J - rownnz += 1 + nnz += 1 else: efc_J_out[worldid, efcid, dofid] = J Jqvel += J * qvel_in[worldid, dofid] + if is_sparse and nnz >= rownnz: + break # Advance tree pointers and recompute da for next iteration if da1 == da: @@ -1772,91 +1905,95 @@ def _contact_pyramidal( efc_J_out[worldid, efcid, dofid] = 0.0 dofid -= 1 - if is_sparse: - efc_J_rownnz_out[worldid, efcid] = rownnz - if condim == 1: efc_type = ConstraintType.CONTACT_FRICTIONLESS else: efc_type = ConstraintType.CONTACT_PYRAMIDAL _efc_row( - opt_disableflags, - worldid, - timestep, - efcid, - pos, - pos, - invweight, - solref_in[conid], - solimp_in[conid], - includemargin, - Jqvel, - 0.0, - efc_type, - conid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + timestep, + efcid, + pos, + pos, + invweight, + solref_in[conid], + solimp_in[conid], + includemargin, + Jqvel, + 0.0, + efc_type, + conid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @wp.kernel def _contact_elliptic( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_disableflags: int, - opt_impratio_invsqrt: wp.array(dtype=float), - body_parentid: wp.array(dtype=int), - body_rootid: wp.array(dtype=int), - body_weldid: wp.array(dtype=int), - body_dofnum: wp.array(dtype=int), - body_dofadr: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=wp.vec2), - dof_bodyid: wp.array(dtype=int), - dof_parentid: wp.array(dtype=int), - geom_bodyid: wp.array(dtype=int), - is_sparse: bool, - # Data in: - qvel_in: wp.array2d(dtype=float), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - njmax_in: int, - nacon_in: wp.array(dtype=int), - # In: - dist_in: wp.array(dtype=float), - condim_in: wp.array(dtype=int), - includemargin_in: wp.array(dtype=float), - worldid_in: wp.array(dtype=int), - geom_in: wp.array(dtype=wp.vec2i), - pos_in: wp.array(dtype=wp.vec3), - frame_in: wp.array(dtype=wp.mat33), - friction_in: wp.array(dtype=vec5), - solref_in: wp.array(dtype=wp.vec2), - solreffriction_in: wp.array(dtype=wp.vec2), - solimp_in: wp.array(dtype=vec5), - type_in: wp.array(dtype=int), - # Data out: - nefc_out: wp.array(dtype=int), - contact_efc_address_out: wp.array2d(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_rownnz_out: wp.array2d(dtype=int), - efc_J_rowadr_out: wp.array2d(dtype=int), - efc_J_colind_out: wp.array3d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + opt_impratio_invsqrt: wp.array(dtype=float), + body_parentid: wp.array(dtype=int), + body_rootid: wp.array(dtype=int), + body_weldid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + body_invweight0: wp.array2d(dtype=wp.vec2), + dof_bodyid: wp.array(dtype=int), + dof_parentid: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + flex_vertadr: wp.array(dtype=int), + flex_vertbodyid: wp.array(dtype=int), + is_sparse: bool, + # Data in: + qvel_in: wp.array2d(dtype=float), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + njmax_in: int, + njmax_nnz_in: int, + nacon_in: wp.array(dtype=int), + # In: + dist_in: wp.array(dtype=float), + condim_in: wp.array(dtype=int), + includemargin_in: wp.array(dtype=float), + worldid_in: wp.array(dtype=int), + geom_in: wp.array(dtype=wp.vec2i), + flex_in: wp.array(dtype=wp.vec2i), + vert_in: wp.array(dtype=wp.vec2i), + pos_in: wp.array(dtype=wp.vec3), + frame_in: wp.array(dtype=wp.mat33), + friction_in: wp.array(dtype=vec5), + solref_in: wp.array(dtype=wp.vec2), + solreffriction_in: wp.array(dtype=wp.vec2), + solimp_in: wp.array(dtype=vec5), + type_in: wp.array(dtype=int), + # Data out: + nefc_out: wp.array(dtype=int), + contact_efc_address_out: wp.array2d(dtype=int), + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_J_rownnz_out: wp.array2d(dtype=int), + efc_J_rowadr_out: wp.array2d(dtype=int), + efc_J_colind_out: wp.array3d(dtype=int), + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + # Out: + efc_nnz_out: wp.array(dtype=int), ): conid, dimid = wp.tid() @@ -1888,35 +2025,67 @@ def _contact_elliptic( contact_efc_address_out[conid, dimid] = efcid geom = geom_in[conid] - body1 = geom_bodyid[geom[0]] - body2 = geom_bodyid[geom[1]] + + if geom[0] >= 0: + body1 = geom_bodyid[geom[0]] + else: + flex = flex_in[conid] + vert = vert_in[conid] + body1 = flex_vertbodyid[flex_vertadr[flex[0]] + vert[0]] + + if geom[1] >= 0: + body2 = geom_bodyid[geom[1]] + else: + flex = flex_in[conid] + vert = vert_in[conid] + body2 = flex_vertbodyid[flex_vertadr[flex[1]] + vert[1]] con_pos = pos_in[conid] frame = frame_in[conid] - if is_sparse: - rowadr = efcid * nv - efc_J_rowadr_out[worldid, efcid] = rowadr - Jqvel = float(0.0) # skip fixed bodies body1 = body_weldid[body1] body2 = body_weldid[body2] - da1 = body_dofadr[body1] + body_dofnum[body1] - 1 - da2 = body_dofadr[body2] + body_dofnum[body2] - 1 - da = wp.max(da1, da2) + da1 = int(body_dofadr[body1] + body_dofnum[body1] - 1) + da2 = int(body_dofadr[body2] + body_dofnum[body2] - 1) if is_sparse: + # count non-zeros + pda1 = da1 + pda2 = da2 rownnz = int(0) + while pda1 >= 0 or pda2 >= 0: + da = wp.max(pda1, pda2) + # skip common dofs + if pda1 == da and pda2 == da: + break + if pda1 == da: + pda1 = dof_parentid[pda1] + if pda2 == da: + pda2 = dof_parentid[pda2] + rownnz += 1 + + # get rowadr + rowadr = wp.atomic_add(efc_nnz_out, worldid, rownnz) + if rowadr + rownnz > njmax_nnz_in: + return + efc_J_rowadr_out[worldid, efcid] = rowadr + efc_J_rownnz_out[worldid, efcid] = rownnz + + da = wp.max(da1, da2) + + if is_sparse: + nnz = int(0) dofid = int(da) else: dofid = int(nv - 1) while True: if is_sparse: - if da1 < 0 and da2 < 0: + if nnz >= rownnz: break else: if dofid < 0: @@ -1957,13 +2126,15 @@ def _contact_elliptic( J += frame[dimid - 3, xyz] * jac_dif if is_sparse: - sparseid = rowadr + rownnz + sparseid = rowadr + nnz efc_J_colind_out[worldid, 0, sparseid] = dofid efc_J_out[worldid, 0, sparseid] = J - rownnz += 1 + nnz += 1 else: efc_J_out[worldid, efcid, dofid] = J Jqvel += J * qvel_in[worldid, dofid] + if is_sparse and nnz >= rownnz: + break # Advance tree pointers and recompute da for next iteration if da1 == da: @@ -1980,9 +2151,6 @@ def _contact_elliptic( efc_J_out[worldid, efcid, dofid] = 0.0 dofid -= 1 - if is_sparse: - efc_J_rownnz_out[worldid, efcid] = rownnz - body_invweight0_id = worldid % body_invweight0.shape[0] invweight = body_invweight0[body_invweight0_id, body1][0] + body_invweight0[body_invweight0_id, body2][0] @@ -2013,561 +2181,599 @@ def _contact_elliptic( efc_type = ConstraintType.CONTACT_ELLIPTIC _efc_row( - opt_disableflags, - worldid, - timestep, - efcid, - pos_aref, - pos, - invweight, - ref, - solimp_in[conid], - includemargin, - Jqvel, - 0.0, - efc_type, - conid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, + opt_disableflags, + worldid, + timestep, + efcid, + pos_aref, + pos, + invweight, + ref, + solimp_in[conid], + includemargin, + Jqvel, + 0.0, + efc_type, + conid, + efc_type_out, + efc_id_out, + efc_pos_out, + efc_margin_out, + efc_D_out, + efc_vel_out, + efc_aref_out, + efc_frictionloss_out, ) @event_scope def make_constraint(m: types.Model, d: types.Data): """Creates constraint jacobians and other supporting data.""" + efc_nnz = wp.empty((d.nworld,), dtype=int) + wp.launch( _zero_constraint_counts, dim=d.nworld, - inputs=[d.ne, d.nf, d.nl, d.nefc], + inputs=[d.ne, d.nf, d.nl, d.nefc, efc_nnz], ) - if types.SPARSE_CONSTRAINT_JACOBIAN: - d.contact.efc_address.fill_(-1) - if not (m.opt.disableflags & types.DisableBit.CONSTRAINT): if not (m.opt.disableflags & types.DisableBit.EQUALITY): wp.launch( - _equality_connect, - dim=(d.nworld, m.eq_connect_adr.size), - inputs=[ - m.nv, - m.nsite, - m.opt.timestep, - m.opt.disableflags, - m.body_parentid, - m.body_rootid, - m.body_weldid, - m.body_dofnum, - m.body_dofadr, - m.body_invweight0, - m.dof_bodyid, - m.dof_parentid, - m.site_bodyid, - m.eq_obj1id, - m.eq_obj2id, - m.eq_objtype, - m.eq_solref, - m.eq_solimp, - m.eq_data, - SPARSE_CONSTRAINT_JACOBIAN, - m.eq_connect_adr, - d.qvel, - d.eq_active, - d.xpos, - d.xmat, - d.site_xpos, - d.subtree_com, - d.cdof, - d.njmax, - ], - outputs=[ - d.ne, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _equality_connect, + dim=(d.nworld, m.eq_connect_adr.size), + inputs=[ + m.nv, + m.nsite, + m.opt.timestep, + m.opt.disableflags, + m.body_parentid, + m.body_rootid, + m.body_weldid, + m.body_dofnum, + m.body_dofadr, + m.body_invweight0, + m.dof_bodyid, + m.dof_parentid, + m.site_bodyid, + m.eq_obj1id, + m.eq_obj2id, + m.eq_objtype, + m.eq_solref, + m.eq_solimp, + m.eq_data, + m.is_sparse, + m.eq_connect_adr, + d.qvel, + d.eq_active, + d.xpos, + d.xmat, + d.site_xpos, + d.subtree_com, + d.cdof, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.ne, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) wp.launch( - _equality_weld, - dim=(d.nworld, m.eq_wld_adr.size), - inputs=[ - m.nv, - m.nsite, - m.opt.timestep, - m.opt.disableflags, - m.body_parentid, - m.body_rootid, - m.body_weldid, - m.body_dofnum, - m.body_dofadr, - m.body_invweight0, - m.dof_bodyid, - m.dof_parentid, - m.site_bodyid, - m.site_quat, - m.eq_obj1id, - m.eq_obj2id, - m.eq_objtype, - m.eq_solref, - m.eq_solimp, - m.eq_data, - SPARSE_CONSTRAINT_JACOBIAN, - m.eq_wld_adr, - d.qvel, - d.eq_active, - d.xpos, - d.xquat, - d.xmat, - d.site_xpos, - d.subtree_com, - d.cdof, - d.njmax, - ], - outputs=[ - d.ne, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _equality_weld, + dim=(d.nworld, m.eq_wld_adr.size), + inputs=[ + m.nv, + m.nsite, + m.opt.timestep, + m.opt.disableflags, + m.body_parentid, + m.body_rootid, + m.body_weldid, + m.body_dofnum, + m.body_dofadr, + m.body_invweight0, + m.dof_bodyid, + m.dof_parentid, + m.site_bodyid, + m.site_quat, + m.eq_obj1id, + m.eq_obj2id, + m.eq_objtype, + m.eq_solref, + m.eq_solimp, + m.eq_data, + m.is_sparse, + m.eq_wld_adr, + d.qvel, + d.eq_active, + d.xpos, + d.xquat, + d.xmat, + d.site_xpos, + d.subtree_com, + d.cdof, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.ne, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) wp.launch( - _equality_joint, - dim=(d.nworld, m.eq_jnt_adr.size), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.qpos0, - m.jnt_qposadr, - m.jnt_dofadr, - m.dof_invweight0, - m.eq_obj1id, - m.eq_obj2id, - m.eq_solref, - m.eq_solimp, - m.eq_data, - SPARSE_CONSTRAINT_JACOBIAN, - m.eq_jnt_adr, - d.qpos, - d.qvel, - d.eq_active, - d.njmax, - ], - outputs=[ - d.ne, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _equality_joint, + dim=(d.nworld, m.eq_jnt_adr.size), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.qpos0, + m.jnt_qposadr, + m.jnt_dofadr, + m.dof_invweight0, + m.eq_obj1id, + m.eq_obj2id, + m.eq_solref, + m.eq_solimp, + m.eq_data, + m.is_sparse, + m.eq_jnt_adr, + d.qpos, + d.qvel, + d.eq_active, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.ne, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) wp.launch( - _equality_tendon, - dim=(d.nworld, m.eq_ten_adr.size), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.eq_obj1id, - m.eq_obj2id, - m.eq_solref, - m.eq_solimp, - m.eq_data, - m.tendon_length0, - m.tendon_invweight0, - SPARSE_CONSTRAINT_JACOBIAN, - m.eq_ten_adr, - d.qvel, - d.eq_active, - d.ten_J, - d.ten_length, - d.njmax, - ], - outputs=[ - d.ne, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _equality_tendon, + dim=(d.nworld, m.eq_ten_adr.size), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.eq_obj1id, + m.eq_obj2id, + m.eq_solref, + m.eq_solimp, + m.eq_data, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, + m.tendon_length0, + m.tendon_invweight0, + m.is_sparse, + m.eq_ten_adr, + d.qvel, + d.eq_active, + d.ten_J, + d.ten_length, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.ne, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) wp.launch( - _equality_flex(SPARSE_CONSTRAINT_JACOBIAN), - dim=(d.nworld, m.eq_flex_adr.size, m.nflexedge), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.flexedge_length0, - m.flexedge_invweight0, - m.flexedge_J_rownnz, - m.flexedge_J_rowadr, - m.flexedge_J_colind, - m.eq_solref, - m.eq_solimp, - m.eq_flex_adr, - d.qvel, - d.flexedge_J, - d.flexedge_length, - d.njmax, - ], - outputs=[ - d.ne, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _equality_flex(m.is_sparse), + dim=(d.nworld, m.eq_flex_adr.size, m.nflexedge), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.flex_edgeadr, + m.flex_edgenum, + m.flexedge_length0, + m.flexedge_invweight0, + m.flexedge_J_rownnz, + m.flexedge_J_rowadr, + m.flexedge_J_colind, + m.eq_obj1id, + m.eq_solref, + m.eq_solimp, + m.eq_flex_adr, + d.qvel, + d.flexedge_J, + d.flexedge_length, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.ne, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) if not (m.opt.disableflags & types.DisableBit.FRICTIONLOSS): wp.launch( - _friction_dof, - dim=(d.nworld, m.nv), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.dof_solref, - m.dof_solimp, - m.dof_frictionloss, - m.dof_invweight0, - SPARSE_CONSTRAINT_JACOBIAN, - d.qvel, - d.njmax, - ], - outputs=[ - d.nf, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _friction_dof, + dim=(d.nworld, m.nv), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.dof_solref, + m.dof_solimp, + m.dof_frictionloss, + m.dof_invweight0, + m.is_sparse, + d.qvel, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.nf, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) wp.launch( - _friction_tendon, - dim=(d.nworld, m.ntendon), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.tendon_solref_fri, - m.tendon_solimp_fri, - m.tendon_frictionloss, - m.tendon_invweight0, - SPARSE_CONSTRAINT_JACOBIAN, - d.qvel, - d.ten_J, - d.njmax, - ], - outputs=[ - d.nf, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _friction_tendon, + dim=(d.nworld, m.ntendon), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, + m.tendon_solref_fri, + m.tendon_solimp_fri, + m.tendon_frictionloss, + m.tendon_invweight0, + m.is_sparse, + d.qvel, + d.ten_J, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.nf, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) # limit if not (m.opt.disableflags & types.DisableBit.LIMIT): wp.launch( - _limit_ball, - dim=(d.nworld, m.jnt_limited_ball_adr.size), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.jnt_qposadr, - m.jnt_dofadr, - m.jnt_solref, - m.jnt_solimp, - m.jnt_range, - m.jnt_margin, - m.dof_invweight0, - SPARSE_CONSTRAINT_JACOBIAN, - m.jnt_limited_ball_adr, - d.qpos, - d.qvel, - d.njmax, - ], - outputs=[ - d.nl, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _limit_ball, + dim=(d.nworld, m.jnt_limited_ball_adr.size), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.jnt_qposadr, + m.jnt_dofadr, + m.jnt_solref, + m.jnt_solimp, + m.jnt_range, + m.jnt_margin, + m.dof_invweight0, + m.is_sparse, + m.jnt_limited_ball_adr, + d.qpos, + d.qvel, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.nl, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) wp.launch( - _limit_slide_hinge, - dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.jnt_qposadr, - m.jnt_dofadr, - m.jnt_solref, - m.jnt_solimp, - m.jnt_range, - m.jnt_margin, - m.dof_invweight0, - SPARSE_CONSTRAINT_JACOBIAN, - m.jnt_limited_slide_hinge_adr, - d.qpos, - d.qvel, - d.njmax, - ], - outputs=[ - d.nl, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _limit_slide_hinge, + dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.jnt_qposadr, + m.jnt_dofadr, + m.jnt_solref, + m.jnt_solimp, + m.jnt_range, + m.jnt_margin, + m.dof_invweight0, + m.is_sparse, + m.jnt_limited_slide_hinge_adr, + d.qpos, + d.qvel, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.nl, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) wp.launch( - _limit_tendon, - dim=(d.nworld, m.tendon_limited_adr.size), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.jnt_dofadr, - m.tendon_adr, - m.tendon_num, - m.tendon_solref_lim, - m.tendon_solimp_lim, - m.tendon_range, - m.tendon_margin, - m.tendon_invweight0, - m.wrap_type, - m.wrap_objid, - SPARSE_CONSTRAINT_JACOBIAN, - m.tendon_limited_adr, - d.qvel, - d.ten_J, - d.ten_length, - d.njmax, - ], - outputs=[ - d.nl, - d.nefc, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _limit_tendon, + dim=(d.nworld, m.tendon_limited_adr.size), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, + m.tendon_solref_lim, + m.tendon_solimp_lim, + m.tendon_range, + m.tendon_margin, + m.tendon_invweight0, + m.is_sparse, + m.tendon_limited_adr, + d.qvel, + d.ten_J, + d.ten_length, + d.njmax, + d.njmax_nnz, + ], + outputs=[ + d.nl, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) # contact if not (m.opt.disableflags & types.DisableBit.CONTACT): if m.opt.cone == types.ConeType.PYRAMIDAL: wp.launch( - _contact_pyramidal, - dim=(d.naconmax, m.nmaxpyramid), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.opt.impratio_invsqrt, - m.body_parentid, - m.body_rootid, - m.body_weldid, - m.body_dofnum, - m.body_dofadr, - m.body_invweight0, - m.dof_bodyid, - m.dof_parentid, - m.geom_bodyid, - SPARSE_CONSTRAINT_JACOBIAN, - d.qvel, - d.subtree_com, - d.cdof, - d.njmax, - d.nacon, - d.contact.dist, - d.contact.dim, - d.contact.includemargin, - d.contact.worldid, - d.contact.geom, - d.contact.pos, - d.contact.frame, - d.contact.friction, - d.contact.solref, - d.contact.solimp, - d.contact.type, - ], - outputs=[ - d.nefc, - d.contact.efc_address, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _contact_pyramidal, + dim=(d.naconmax, m.nmaxpyramid), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.opt.impratio_invsqrt, + m.body_parentid, + m.body_rootid, + m.body_weldid, + m.body_dofnum, + m.body_dofadr, + m.body_invweight0, + m.dof_bodyid, + m.dof_parentid, + m.geom_bodyid, + m.flex_vertadr, + m.flex_vertbodyid, + m.is_sparse, + d.qvel, + d.subtree_com, + d.cdof, + d.njmax, + d.njmax_nnz, + d.nacon, + d.contact.dist, + d.contact.dim, + d.contact.includemargin, + d.contact.worldid, + d.contact.geom, + d.contact.flex, + d.contact.vert, + d.contact.pos, + d.contact.frame, + d.contact.friction, + d.contact.solref, + d.contact.solimp, + d.contact.type, + ], + outputs=[ + d.nefc, + d.contact.efc_address, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) elif m.opt.cone == types.ConeType.ELLIPTIC: wp.launch( - _contact_elliptic, - dim=(d.naconmax, m.nmaxcondim), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.disableflags, - m.opt.impratio_invsqrt, - m.body_parentid, - m.body_rootid, - m.body_weldid, - m.body_dofnum, - m.body_dofadr, - m.body_invweight0, - m.dof_bodyid, - m.dof_parentid, - m.geom_bodyid, - SPARSE_CONSTRAINT_JACOBIAN, - d.qvel, - d.subtree_com, - d.cdof, - d.njmax, - d.nacon, - d.contact.dist, - d.contact.dim, - d.contact.includemargin, - d.contact.worldid, - d.contact.geom, - d.contact.pos, - d.contact.frame, - d.contact.friction, - d.contact.solref, - d.contact.solreffriction, - d.contact.solimp, - d.contact.type, - ], - outputs=[ - d.nefc, - d.contact.efc_address, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], + _contact_elliptic, + dim=(d.naconmax, m.nmaxcondim), + inputs=[ + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.opt.impratio_invsqrt, + m.body_parentid, + m.body_rootid, + m.body_weldid, + m.body_dofnum, + m.body_dofadr, + m.body_invweight0, + m.dof_bodyid, + m.dof_parentid, + m.geom_bodyid, + m.flex_vertadr, + m.flex_vertbodyid, + m.is_sparse, + d.qvel, + d.subtree_com, + d.cdof, + d.njmax, + d.njmax_nnz, + d.nacon, + d.contact.dist, + d.contact.dim, + d.contact.includemargin, + d.contact.worldid, + d.contact.geom, + d.contact.flex, + d.contact.vert, + d.contact.pos, + d.contact.frame, + d.contact.friction, + d.contact.solref, + d.contact.solreffriction, + d.contact.solimp, + d.contact.type, + ], + outputs=[ + d.nefc, + d.contact.efc_address, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + efc_nnz, + ], ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/derivative.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/derivative.py index 406c36c4aa..20da751ac1 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/derivative.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/derivative.py @@ -15,6 +15,7 @@ import warp as wp +from mujoco.mjx.third_party.mujoco_warp._src.support import next_act from mujoco.mjx.third_party.mujoco_warp._src.types import BiasType from mujoco.mjx.third_party.mujoco_warp._src.types import Data from mujoco.mjx.third_party.mujoco_warp._src.types import DisableBit @@ -30,18 +31,24 @@ @wp.kernel def _qderiv_actuator_passive_vel( # Model: + opt_timestep: wp.array(dtype=float), actuator_dyntype: wp.array(dtype=int), actuator_gaintype: wp.array(dtype=int), actuator_biastype: wp.array(dtype=int), actuator_actadr: wp.array(dtype=int), actuator_actnum: wp.array(dtype=int), actuator_forcelimited: wp.array(dtype=bool), + actuator_actlimited: wp.array(dtype=bool), + actuator_dynprm: wp.array2d(dtype=vec10f), actuator_gainprm: wp.array2d(dtype=vec10f), actuator_biasprm: wp.array2d(dtype=vec10f), + actuator_actearly: wp.array(dtype=bool), actuator_forcerange: wp.array2d(dtype=wp.vec2), + actuator_actrange: wp.array2d(dtype=wp.vec2), # Data in: act_in: wp.array2d(dtype=float), ctrl_in: wp.array2d(dtype=float), + act_dot_in: wp.array2d(dtype=float), actuator_force_in: wp.array2d(dtype=float), # Out: vel_out: wp.array2d(dtype=float), @@ -76,9 +83,24 @@ def _qderiv_actuator_passive_vel( vel = float(bias) if actuator_dyntype[actid] != DynType.NONE: if gain != 0.0: - act_first = actuator_actadr[actid] - act_last = act_first + actuator_actnum[actid] - 1 - vel += gain * act_in[worldid, act_last] + act_adr = actuator_actadr[actid] + actuator_actnum[actid] - 1 + + # use next activation if actearly is set (matching forward pass) + if actuator_actearly[actid]: + act = next_act( + opt_timestep[worldid % opt_timestep.shape[0]], + actuator_dyntype[actid], + actuator_dynprm[worldid % actuator_dynprm.shape[0], actid], + actuator_actrange[worldid % actuator_actrange.shape[0], actid], + act_in[worldid, act_adr], + act_dot_in[worldid, act_adr], + 1.0, + actuator_actlimited[actid], + ) + else: + act = act_in[worldid, act_adr] + + vel += gain * act else: if gain != 0.0: vel += gain * ctrl_in[worldid, actid] @@ -95,21 +117,20 @@ def _nonzero_mask(x: float) -> float: @wp.kernel -def _qderiv_actuator_passive_actuation_sparse( - # Model: - nu: int, - is_sparse: bool, - # Data in: - moment_rownnz_in: wp.array2d(dtype=int), - moment_rowadr_in: wp.array2d(dtype=int), - moment_colind_in: wp.array2d(dtype=int), - actuator_moment_in: wp.array2d(dtype=float), - # In: - vel_in: wp.array2d(dtype=float), - qMi: wp.array(dtype=int), - qMj: wp.array(dtype=int), - # Out: - qDeriv_out: wp.array3d(dtype=float), +def _qderiv_actuator_passive_actuation_dense( + # Model: + nu: int, + # Data in: + moment_rownnz_in: wp.array2d(dtype=int), + moment_rowadr_in: wp.array2d(dtype=int), + moment_colind_in: wp.array2d(dtype=int), + actuator_moment_in: wp.array2d(dtype=float), + # In: + vel_in: wp.array2d(dtype=float), + qMi: wp.array(dtype=int), + qMj: wp.array(dtype=int), + # Out: + qDeriv_out: wp.array3d(dtype=float), ): worldid, elemid = wp.tid() @@ -142,12 +163,63 @@ def _qderiv_actuator_passive_actuation_sparse( qderiv_contrib += moment_i * moment_j * vel - if is_sparse: - qDeriv_out[worldid, 0, elemid] = qderiv_contrib - else: - qDeriv_out[worldid, dofiid, dofjid] = qderiv_contrib - if dofiid != dofjid: - qDeriv_out[worldid, dofjid, dofiid] = qderiv_contrib + qDeriv_out[worldid, dofiid, dofjid] = qderiv_contrib + if dofiid != dofjid: + qDeriv_out[worldid, dofjid, dofiid] = qderiv_contrib + + +@wp.kernel +def _qderiv_actuator_passive_actuation_sparse( + # Model: + M_rownnz: wp.array(dtype=int), + M_rowadr: wp.array(dtype=int), + # Data in: + moment_rownnz_in: wp.array2d(dtype=int), + moment_rowadr_in: wp.array2d(dtype=int), + moment_colind_in: wp.array2d(dtype=int), + actuator_moment_in: wp.array2d(dtype=float), + # In: + vel_in: wp.array2d(dtype=float), + qMj: wp.array(dtype=int), + # Out: + qDeriv_out: wp.array3d(dtype=float), +): + worldid, actid = wp.tid() + + vel = vel_in[worldid, actid] + if vel == 0.0: + return + + rownnz = moment_rownnz_in[worldid, actid] + rowadr = moment_rowadr_in[worldid, actid] + + for i in range(rownnz): + rowadri = rowadr + i + moment_i = actuator_moment_in[worldid, rowadri] + if moment_i == 0.0: + continue + dofi = moment_colind_in[worldid, rowadri] + + for j in range(i + 1): + rowadrj = rowadr + j + moment_j = actuator_moment_in[worldid, rowadrj] + if moment_j == 0.0: + continue + dofj = moment_colind_in[worldid, rowadrj] + + contrib = moment_i * moment_j * vel + + # Search the corresponding elemid + # TODO: This could be precalculated for improved performance + row = dofi + col = dofj + row_startk = M_rowadr[row] - 1 + row_nnz = M_rownnz[row] + for k in range(row_nnz): + row_startk += 1 + if qMj[row_startk] == col: + wp.atomic_add(qDeriv_out[worldid, 0], row_startk, contrib) + break @wp.kernel @@ -176,7 +248,7 @@ def _qderiv_actuator_passive( else: qderiv = qDeriv_in[worldid, dofiid, dofjid] - if not opt_disableflags & DisableBit.DAMPER and dofiid == dofjid: + if not (opt_disableflags & DisableBit.DAMPER) and dofiid == dofjid: qderiv -= dof_damping[worldid % dof_damping.shape[0], dofiid] qderiv *= opt_timestep[worldid % opt_timestep.shape[0]] @@ -196,10 +268,13 @@ def _qderiv_tendon_damping( # Model: ntendon: int, opt_timestep: wp.array(dtype=float), + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), tendon_damping: wp.array2d(dtype=float), is_sparse: bool, # Data in: - ten_J_in: wp.array3d(dtype=float), + ten_J_in: wp.array2d(dtype=float), # In: qMi: wp.array(dtype=int), qMj: wp.array(dtype=int), @@ -213,7 +288,24 @@ def _qderiv_tendon_damping( qderiv = float(0.0) tendon_damping_id = worldid % tendon_damping.shape[0] for tenid in range(ntendon): - qderiv -= ten_J_in[worldid, tenid, dofiid] * ten_J_in[worldid, tenid, dofjid] * tendon_damping[tendon_damping_id, tenid] + damping = tendon_damping[tendon_damping_id, tenid] + if damping == 0.0: + continue + + rownnz = ten_J_rownnz[tenid] + rowadr = ten_J_rowadr[tenid] + Ji = float(0.0) + Jj = float(0.0) + for k in range(rownnz): + if Ji != 0.0 and Jj != 0.0: + break + sparseid = rowadr + k + colind = ten_J_colind[sparseid] + if colind == dofiid: + Ji = ten_J_in[worldid, sparseid] + if colind == dofjid: + Jj = ten_J_in[worldid, sparseid] + qderiv -= Ji * Jj * damping qderiv *= opt_timestep[worldid % opt_timestep.shape[0]] @@ -242,43 +334,47 @@ def deriv_smooth_vel(m: Model, d: Data, out: wp.array2d(dtype=float)): if ~(m.opt.disableflags & (DisableBit.ACTUATION | DisableBit.DAMPER)): # TODO(team): only clear elements not set by _qderiv_actuator_passive out.zero_() - if m.nu > 0 and not m.opt.disableflags & DisableBit.ACTUATION: + if m.nu > 0 and not (m.opt.disableflags & DisableBit.ACTUATION): vel = wp.empty((d.nworld, m.nu), dtype=float) wp.launch( _qderiv_actuator_passive_vel, dim=(d.nworld, m.nu), inputs=[ + m.opt.timestep, m.actuator_dyntype, m.actuator_gaintype, m.actuator_biastype, m.actuator_actadr, m.actuator_actnum, m.actuator_forcelimited, + m.actuator_actlimited, + m.actuator_dynprm, m.actuator_gainprm, m.actuator_biasprm, + m.actuator_actearly, m.actuator_forcerange, + m.actuator_actrange, d.act, d.ctrl, + d.act_dot, d.actuator_force, ], outputs=[vel], ) - wp.launch( + if m.is_sparse: + wp.launch( _qderiv_actuator_passive_actuation_sparse, + dim=(d.nworld, m.nu), + inputs=[m.M_rownnz, m.M_rowadr, d.moment_rownnz, d.moment_rowadr, d.moment_colind, d.actuator_moment, vel, qMj], + outputs=[out], + ) + else: + wp.launch( + _qderiv_actuator_passive_actuation_dense, dim=(d.nworld, qMi.size), - inputs=[ - m.nu, - m.is_sparse, - d.moment_rownnz, - d.moment_rowadr, - d.moment_colind, - d.actuator_moment, - vel, - qMi, - qMj, - ], + inputs=[m.nu, d.moment_rownnz, d.moment_rowadr, d.moment_colind, d.actuator_moment, vel, qMi, qMj], outputs=[out], - ) + ) wp.launch( _qderiv_actuator_passive, dim=(d.nworld, qMi.size), @@ -298,11 +394,22 @@ def deriv_smooth_vel(m: Model, d: Data, out: wp.array2d(dtype=float)): # TODO(team): directly utilize qM for these settings wp.copy(out, d.qM) - if not m.opt.disableflags & DisableBit.DAMPER: + if not (m.opt.disableflags & DisableBit.DAMPER): wp.launch( _qderiv_tendon_damping, dim=(d.nworld, qMi.size), - inputs=[m.ntendon, m.opt.timestep, m.tendon_damping, m.is_sparse, d.ten_J, qMi, qMj], + inputs=[ + m.ntendon, + m.opt.timestep, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, + m.tendon_damping, + m.is_sparse, + d.ten_J, + qMi, + qMj, + ], outputs=[out], ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/forward.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/forward.py index 0dc3de14eb..64bdd91f5f 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/forward.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/forward.py @@ -15,6 +15,8 @@ from typing import Optional +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src import collision_driver from mujoco.mjx.third_party.mujoco_warp._src import constraint from mujoco.mjx.third_party.mujoco_warp._src import derivative @@ -25,7 +27,9 @@ from mujoco.mjx.third_party.mujoco_warp._src import smooth from mujoco.mjx.third_party.mujoco_warp._src import solver from mujoco.mjx.third_party.mujoco_warp._src import util_misc +from mujoco.mjx.third_party.mujoco_warp._src.support import next_act from mujoco.mjx.third_party.mujoco_warp._src.support import xfrc_accumulate +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL from mujoco.mjx.third_party.mujoco_warp._src.types import BiasType from mujoco.mjx.third_party.mujoco_warp._src.types import Data from mujoco.mjx.third_party.mujoco_warp._src.types import DisableBit @@ -34,14 +38,12 @@ from mujoco.mjx.third_party.mujoco_warp._src.types import GainType from mujoco.mjx.third_party.mujoco_warp._src.types import IntegratorType from mujoco.mjx.third_party.mujoco_warp._src.types import JointType -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL from mujoco.mjx.third_party.mujoco_warp._src.types import Model from mujoco.mjx.third_party.mujoco_warp._src.types import TileSet from mujoco.mjx.third_party.mujoco_warp._src.types import TrnType from mujoco.mjx.third_party.mujoco_warp._src.types import vec10f from mujoco.mjx.third_party.mujoco_warp._src.warp_util import cache_kernel from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope -import warp as wp wp.set_module_options({"enable_backward": False}) @@ -127,55 +129,24 @@ def _next_velocity( qvel_out[worldid, dofid] = qvel_in[worldid, dofid] + qacc_scale_in * qacc_in[worldid, dofid] * timestep -# TODO(team): kernel analyzer array slice? -@wp.func -def _next_act( +@wp.kernel +def _next_activation( # Model: - opt_timestep: float, # kernel_analyzer: ignore - actuator_dyntype: int, # kernel_analyzer: ignore - actuator_dynprm: vec10f, # kernel_analyzer: ignore - actuator_actrange: wp.vec2, # kernel_analyzer: ignore - # Data In: - act_in: float, # kernel_analyzer: ignore - act_dot_in: float, # kernel_analyzer: ignore + opt_timestep: wp.array(dtype=float), + actuator_dyntype: wp.array(dtype=int), + actuator_actadr: wp.array(dtype=int), + actuator_actnum: wp.array(dtype=int), + actuator_actlimited: wp.array(dtype=bool), + actuator_dynprm: wp.array2d(dtype=vec10f), + actuator_actrange: wp.array2d(dtype=wp.vec2), + # Data in: + act_in: wp.array2d(dtype=float), + act_dot_in: wp.array2d(dtype=float), # In: act_dot_scale: float, - clamp: bool, -) -> float: - # advance actuation - if actuator_dyntype == DynType.FILTEREXACT: - tau = wp.max(MJ_MINVAL, actuator_dynprm[0]) - act = act_in + act_dot_scale * act_dot_in * tau * (1.0 - wp.exp(-opt_timestep / tau)) - elif actuator_dyntype == DynType.USER: - return act_in - else: - act = act_in + act_dot_scale * act_dot_in * opt_timestep - - # clamp to actrange - if clamp: - act = wp.clamp(act, actuator_actrange[0], actuator_actrange[1]) - - return act - - -@wp.kernel -def _next_activation( - # Model: - opt_timestep: wp.array(dtype=float), - actuator_dyntype: wp.array(dtype=int), - actuator_actadr: wp.array(dtype=int), - actuator_actnum: wp.array(dtype=int), - actuator_actlimited: wp.array(dtype=bool), - actuator_dynprm: wp.array2d(dtype=vec10f), - actuator_actrange: wp.array2d(dtype=wp.vec2), - # Data in: - act_in: wp.array2d(dtype=float), - act_dot_in: wp.array2d(dtype=float), - # In: - act_dot_scale: float, - limit: bool, - # Data out: - act_out: wp.array2d(dtype=float), + limit: bool, + # Data out: + act_out: wp.array2d(dtype=float), ): worldid, uid = wp.tid() opt_timestep_id = worldid % opt_timestep.shape[0] @@ -184,15 +155,15 @@ def _next_activation( actadr = actuator_actadr[uid] actnum = actuator_actnum[uid] for j in range(actadr, actadr + actnum): - act = _next_act( - opt_timestep[opt_timestep_id], - actuator_dyntype[uid], - actuator_dynprm[actuator_dynprm_id, uid], - actuator_actrange[actuator_actrange_id, uid], - act_in[worldid, j], - act_dot_in[worldid, j], - act_dot_scale, - limit and actuator_actlimited[uid], + act = next_act( + opt_timestep[opt_timestep_id], + actuator_dyntype[uid], + actuator_dynprm[actuator_dynprm_id, uid], + actuator_actrange[actuator_actrange_id, uid], + act_in[worldid, j], + act_dot_in[worldid, j], + act_dot_scale, + limit and actuator_actlimited[uid], ) act_out[worldid, j] = act @@ -201,12 +172,16 @@ def _next_activation( def _next_time( # Model: opt_timestep: wp.array(dtype=float), + is_sparse: bool, # Data in: nefc_in: wp.array(dtype=int), time_in: wp.array(dtype=float), + efc_J_rownnz_in: wp.array2d(dtype=int), + efc_J_rowadr_in: wp.array2d(dtype=int), nworld_in: int, naconmax_in: int, njmax_in: int, + njmax_nnz_in: int, nacon_in: wp.array(dtype=int), ncollision_in: wp.array(dtype=int), # Data out: @@ -218,6 +193,11 @@ def _next_time( if nefc > njmax_in: wp.printf("nefc overflow - please increase njmax to %u\n", nefc) + elif nefc > 0 and is_sparse: + efcid = wp.min(nefc, njmax_in) - 1 + efc_nnz = efc_J_rowadr_in[worldid, efcid] + efc_J_rownnz_in[worldid, efcid] + if efc_nnz > njmax_nnz_in: + wp.printf("njmax_nnz overflow - please increase njmax_nnz to %u\n", efc_nnz) if worldid == 0: ncollision = ncollision_in[0] @@ -236,22 +216,22 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) # advance activations wp.launch( - _next_activation, - dim=(d.nworld, m.nu), - inputs=[ - m.opt.timestep, - m.actuator_dyntype, - m.actuator_actadr, - m.actuator_actnum, - m.actuator_actlimited, - m.actuator_dynprm, - m.actuator_actrange, - d.act, - d.act_dot, - 1.0, - True, - ], - outputs=[d.act], + _next_activation, + dim=(d.nworld, m.nu), + inputs=[ + m.opt.timestep, + m.actuator_dyntype, + m.actuator_actadr, + m.actuator_actnum, + m.actuator_actlimited, + m.actuator_dynprm, + m.actuator_actrange, + d.act, + d.act_dot, + 1.0, + True, + ], + outputs=[d.act], ) wp.launch( @@ -274,7 +254,20 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) wp.launch( _next_time, dim=d.nworld, - inputs=[m.opt.timestep, d.nefc, d.time, d.nworld, d.naconmax, d.njmax, d.nacon, d.ncollision], + inputs=[ + m.opt.timestep, + m.is_sparse, + d.nefc, + d.time, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.nworld, + d.naconmax, + d.njmax, + d.njmax_nnz, + d.nacon, + d.ncollision, + ], outputs=[d.time], ) @@ -294,9 +287,7 @@ def _euler_damp_qfrc_sparse( timestep = opt_timestep[worldid % opt_timestep.shape[0]] adr = dof_Madr[tid] - qM_integration_out[worldid, 0, adr] += ( - timestep * dof_damping[worldid % dof_damping.shape[0], tid] - ) + qM_integration_out[worldid, 0, adr] += timestep * dof_damping[worldid % dof_damping.shape[0], tid] @cache_kernel @@ -336,7 +327,7 @@ def euler_dense( def euler(m: Model, d: Data): """Euler integrator, semi-implicit in velocity.""" # integrate damping implicitly - if not m.opt.disableflags & (DisableBit.EULERDAMP | DisableBit.DAMPER): + if not (m.opt.disableflags & (DisableBit.EULERDAMP | DisableBit.DAMPER)): qacc = wp.empty((d.nworld, m.nv), dtype=float) if m.is_sparse: qM = wp.clone(d.qM) @@ -390,22 +381,22 @@ def _rk_perturb_state( # activation if m.na and act_t0 is not None: wp.launch( - _next_activation, - dim=(d.nworld, m.nu), - inputs=[ - m.opt.timestep, - m.actuator_dyntype, - m.actuator_actadr, - m.actuator_actnum, - m.actuator_actlimited, - m.actuator_dynprm, - m.actuator_actrange, - act_t0, - d.act_dot, - scale, - False, - ], - outputs=[d.act], + _next_activation, + dim=(d.nworld, m.nu), + inputs=[ + m.opt.timestep, + m.actuator_dyntype, + m.actuator_actadr, + m.actuator_actnum, + m.actuator_actlimited, + m.actuator_dynprm, + m.actuator_actrange, + act_t0, + d.act_dot, + scale, + False, + ], + outputs=[d.act], ) @@ -548,14 +539,14 @@ def fwd_position(m: Model, d: Data, factorize: bool = True): @wp.kernel def _actuator_velocity( - # Data in: - qvel_in: wp.array2d(dtype=float), - moment_rownnz_in: wp.array2d(dtype=int), - moment_rowadr_in: wp.array2d(dtype=int), - moment_colind_in: wp.array2d(dtype=int), - actuator_moment_in: wp.array2d(dtype=float), - # Data out: - actuator_velocity_out: wp.array2d(dtype=float), + # Data in: + qvel_in: wp.array2d(dtype=float), + moment_rownnz_in: wp.array2d(dtype=int), + moment_rowadr_in: wp.array2d(dtype=int), + moment_colind_in: wp.array2d(dtype=int), + actuator_moment_in: wp.array2d(dtype=float), + # Data out: + actuator_velocity_out: wp.array2d(dtype=float), ): worldid, actid = wp.tid() @@ -571,50 +562,49 @@ def _actuator_velocity( actuator_velocity_out[worldid, actid] = vel -@cache_kernel -def _tendon_velocity(nv: int): - @wp.kernel(module="unique", enable_backward=False) - def tendon_velocity( - # Data in: - qvel_in: wp.array2d(dtype=float), - ten_J_in: wp.array3d(dtype=float), - # Data out: - ten_velocity_out: wp.array2d(dtype=float), - ): - worldid, tenid = wp.tid() - ten_J_tile = wp.tile_load(ten_J_in[worldid, tenid], shape=wp.static(nv)) - qvel_tile = wp.tile_load(qvel_in[worldid], shape=wp.static(nv)) - ten_J_qvel_tile = wp.tile_map(wp.mul, ten_J_tile, qvel_tile) - ten_velocity_tile = wp.tile_reduce(wp.add, ten_J_qvel_tile) - ten_velocity_out[worldid, tenid] = ten_velocity_tile[0] +@wp.kernel +def _tendon_velocity( + # Model: + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + # Data in: + qvel_in: wp.array2d(dtype=float), + ten_J_in: wp.array2d(dtype=float), + # Data out: + ten_velocity_out: wp.array2d(dtype=float), +): + worldid, tenid = wp.tid() - return tendon_velocity + velocity = float(0.0) + rownnz = ten_J_rownnz[tenid] + rowadr = ten_J_rowadr[tenid] + for i in range(rownnz): + sparseid = rowadr + i + J = ten_J_in[worldid, sparseid] + if J != 0.0: + colind = ten_J_colind[sparseid] + velocity += J * qvel_in[worldid, colind] + + ten_velocity_out[worldid, tenid] = velocity @event_scope def fwd_velocity(m: Model, d: Data): """Velocity-dependent computations.""" - wp.launch_tiled( - _actuator_velocity, - dim=(d.nworld, m.nu), - inputs=[ - d.qvel, - d.moment_rownnz, - d.moment_rowadr, - d.moment_colind, - d.actuator_moment, - ], - outputs=[d.actuator_velocity], - block_dim=m.block_dim.actuator_velocity, + wp.launch( + _actuator_velocity, + dim=(d.nworld, m.nu), + inputs=[d.qvel, d.moment_rownnz, d.moment_rowadr, d.moment_colind, d.actuator_moment], + outputs=[d.actuator_velocity], + block_dim=m.block_dim.actuator_velocity, ) - # TODO(team): sparse version - wp.launch_tiled( - _tendon_velocity(m.nv), + wp.launch( + _tendon_velocity, dim=(d.nworld, m.ntendon), - inputs=[d.qvel, d.ten_J], + inputs=[m.ten_J_rownnz, m.ten_J_rowadr, m.ten_J_colind, d.qvel, d.ten_J], outputs=[d.ten_velocity], - block_dim=m.block_dim.tendon_velocity, ) smooth.com_vel(m, d) @@ -625,36 +615,36 @@ def fwd_velocity(m: Model, d: Data): @wp.kernel def _actuator_force( - # Model: - na: int, - opt_timestep: wp.array(dtype=float), - actuator_dyntype: wp.array(dtype=int), - actuator_gaintype: wp.array(dtype=int), - actuator_biastype: wp.array(dtype=int), - actuator_actadr: wp.array(dtype=int), - actuator_actnum: wp.array(dtype=int), - actuator_ctrllimited: wp.array(dtype=bool), - actuator_forcelimited: wp.array(dtype=bool), - actuator_actlimited: wp.array(dtype=bool), - actuator_dynprm: wp.array2d(dtype=vec10f), - actuator_gainprm: wp.array2d(dtype=vec10f), - actuator_biasprm: wp.array2d(dtype=vec10f), - actuator_actearly: wp.array(dtype=bool), - actuator_ctrlrange: wp.array2d(dtype=wp.vec2), - actuator_forcerange: wp.array2d(dtype=wp.vec2), - actuator_actrange: wp.array2d(dtype=wp.vec2), - actuator_acc0: wp.array2d(dtype=float), - actuator_lengthrange: wp.array2d(dtype=wp.vec2), - # Data in: - act_in: wp.array2d(dtype=float), - ctrl_in: wp.array2d(dtype=float), - actuator_length_in: wp.array2d(dtype=float), - actuator_velocity_in: wp.array2d(dtype=float), - # In: - dsbl_clampctrl: int, - # Data out: - act_dot_out: wp.array2d(dtype=float), - actuator_force_out: wp.array2d(dtype=float), + # Model: + na: int, + opt_timestep: wp.array(dtype=float), + actuator_dyntype: wp.array(dtype=int), + actuator_gaintype: wp.array(dtype=int), + actuator_biastype: wp.array(dtype=int), + actuator_actadr: wp.array(dtype=int), + actuator_actnum: wp.array(dtype=int), + actuator_ctrllimited: wp.array(dtype=bool), + actuator_forcelimited: wp.array(dtype=bool), + actuator_actlimited: wp.array(dtype=bool), + actuator_dynprm: wp.array2d(dtype=vec10f), + actuator_gainprm: wp.array2d(dtype=vec10f), + actuator_biasprm: wp.array2d(dtype=vec10f), + actuator_actearly: wp.array(dtype=bool), + actuator_ctrlrange: wp.array2d(dtype=wp.vec2), + actuator_forcerange: wp.array2d(dtype=wp.vec2), + actuator_actrange: wp.array2d(dtype=wp.vec2), + actuator_acc0: wp.array2d(dtype=float), + actuator_lengthrange: wp.array2d(dtype=wp.vec2), + # Data in: + act_in: wp.array2d(dtype=float), + ctrl_in: wp.array2d(dtype=float), + actuator_length_in: wp.array2d(dtype=float), + actuator_velocity_in: wp.array2d(dtype=float), + # In: + dsbl_clampctrl: int, + # Data out: + act_dot_out: wp.array2d(dtype=float), + actuator_force_out: wp.array2d(dtype=float), ): worldid, uid = wp.tid() @@ -693,7 +683,7 @@ def _actuator_force( if dyntype == DynType.INTEGRATOR or dyntype == DynType.NONE: act = act_in[worldid, act_last] - ctrl_act = _next_act( + ctrl_act = next_act( opt_timestep[worldid % opt_timestep.shape[0]], dyntype, dynprm, @@ -720,9 +710,7 @@ def _actuator_force( gain = gainprm[0] + gainprm[1] * length + gainprm[2] * velocity elif gaintype == GainType.MUSCLE: acc0 = actuator_acc0[worldid % actuator_acc0.shape[0], uid] - lengthrange = actuator_lengthrange[ - worldid % actuator_lengthrange.shape[0], uid - ] + lengthrange = actuator_lengthrange[worldid % actuator_lengthrange.shape[0], uid] gain = util_misc.muscle_gain(length, velocity, lengthrange, acc0, gainprm) # GainType.USER: gain stays 0, modified by act_gain_callback @@ -735,9 +723,7 @@ def _actuator_force( bias = biasprm[0] + biasprm[1] * length + biasprm[2] * velocity elif biastype == BiasType.MUSCLE: acc0 = actuator_acc0[worldid % actuator_acc0.shape[0], uid] - lengthrange = actuator_lengthrange[ - worldid % actuator_lengthrange.shape[0], uid - ] + lengthrange = actuator_lengthrange[worldid % actuator_lengthrange.shape[0], uid] bias = util_misc.muscle_bias(length, lengthrange, acc0, biasprm) force = gain * ctrl_act + bias @@ -795,14 +781,14 @@ def _tendon_actuator_force_clamp( @wp.kernel def _qfrc_actuator( - # Data in: - moment_rownnz_in: wp.array2d(dtype=int), - moment_rowadr_in: wp.array2d(dtype=int), - moment_colind_in: wp.array2d(dtype=int), - actuator_moment_in: wp.array2d(dtype=float), - actuator_force_in: wp.array2d(dtype=float), - # Data out: - qfrc_actuator_out: wp.array2d(dtype=float), + # Data in: + moment_rownnz_in: wp.array2d(dtype=int), + moment_rowadr_in: wp.array2d(dtype=int), + moment_colind_in: wp.array2d(dtype=int), + actuator_moment_in: wp.array2d(dtype=float), + actuator_force_in: wp.array2d(dtype=float), + # Data out: + qfrc_actuator_out: wp.array2d(dtype=float), ): worldid, actid = wp.tid() @@ -812,26 +798,23 @@ def _qfrc_actuator( for i in range(rownnz): sparseid = rowadr + i colind = moment_colind_in[worldid, sparseid] - qfrc = ( - actuator_moment_in[worldid, sparseid] - * actuator_force_in[worldid, actid] - ) + qfrc = actuator_moment_in[worldid, sparseid] * actuator_force_in[worldid, actid] wp.atomic_add(qfrc_actuator_out[worldid], colind, qfrc) @wp.kernel def _qfrc_actuator_gravcomp_limits( - # Model: - ngravcomp: int, - jnt_actfrclimited: wp.array(dtype=bool), - jnt_actgravcomp: wp.array(dtype=int), - jnt_actfrcrange: wp.array2d(dtype=wp.vec2), - dof_jntid: wp.array(dtype=int), - # Data in: - qfrc_gravcomp_in: wp.array2d(dtype=float), - qfrc_actuator_in: wp.array2d(dtype=float), - # Data out: - qfrc_actuator_out: wp.array2d(dtype=float), + # Model: + ngravcomp: int, + jnt_actfrclimited: wp.array(dtype=bool), + jnt_actgravcomp: wp.array(dtype=int), + jnt_actfrcrange: wp.array2d(dtype=wp.vec2), + dof_jntid: wp.array(dtype=int), + # Data in: + qfrc_gravcomp_in: wp.array2d(dtype=float), + qfrc_actuator_in: wp.array2d(dtype=float), + # Data out: + qfrc_actuator_out: wp.array2d(dtype=float), ): worldid, dofid = wp.tid() jntid = dof_jntid[dofid] @@ -917,30 +900,30 @@ def fwd_actuation(m: Model, d: Data): # TODO(team): optimize performance d.qfrc_actuator.zero_() wp.launch( - _qfrc_actuator, - dim=(d.nworld, m.nu), - inputs=[ - d.moment_rownnz, - d.moment_rowadr, - d.moment_colind, - d.actuator_moment, - d.actuator_force, - ], - outputs=[d.qfrc_actuator], + _qfrc_actuator, + dim=(d.nworld, m.nu), + inputs=[ + d.moment_rownnz, + d.moment_rowadr, + d.moment_colind, + d.actuator_moment, + d.actuator_force, + ], + outputs=[d.qfrc_actuator], ) wp.launch( - _qfrc_actuator_gravcomp_limits, - dim=(d.nworld, m.nv), - inputs=[ - m.ngravcomp, - m.jnt_actfrclimited, - m.jnt_actgravcomp, - m.jnt_actfrcrange, - m.dof_jntid, - d.qfrc_gravcomp, - d.qfrc_actuator, - ], - outputs=[d.qfrc_actuator], + _qfrc_actuator_gravcomp_limits, + dim=(d.nworld, m.nv), + inputs=[ + m.ngravcomp, + m.jnt_actfrclimited, + m.jnt_actgravcomp, + m.jnt_actfrcrange, + m.dof_jntid, + d.qfrc_gravcomp, + d.qfrc_actuator, + ], + outputs=[d.qfrc_actuator], ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py index e516739cb9..0b53094bd2 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py @@ -14,24 +14,24 @@ # ============================================================================== import dataclasses -from typing import Any, Optional, Sequence import warnings +from typing import Any, Optional, Sequence import mujoco +import numpy as np +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src import bvh from mujoco.mjx.third_party.mujoco_warp._src import math as mjmath from mujoco.mjx.third_party.mujoco_warp._src import render_util from mujoco.mjx.third_party.mujoco_warp._src import smooth from mujoco.mjx.third_party.mujoco_warp._src import types from mujoco.mjx.third_party.mujoco_warp._src import warp_util -from mujoco.mjx.third_party.mujoco_warp._src.types import BiasType from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL -from mujoco.mjx.third_party.mujoco_warp._src.types import SPARSE_CONSTRAINT_JACOBIAN +from mujoco.mjx.third_party.mujoco_warp._src.types import BiasType from mujoco.mjx.third_party.mujoco_warp._src.types import TrnType from mujoco.mjx.third_party.mujoco_warp._src.types import vec10 from mujoco.mjx.third_party.mujoco_warp._src.util_pkg import check_version -import numpy as np -import warp as wp def _create_array(data: Any, spec: wp.array, sizes: dict[str, int]) -> wp.array | None: @@ -114,9 +114,6 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: if unsupported: raise NotImplementedError(f"{mj_type(unsupported).name} is unsupported.") - if ((mjm.flex_contype != 0) | (mjm.flex_conaffinity != 0)).any(): - raise NotImplementedError("Flex collisions are not implemented.") - if mjm.opt.noslip_iterations > 0: raise NotImplementedError(f"noslip solver not implemented.") @@ -226,6 +223,8 @@ def _check_friction(name: str, id_: int, condim: int, friction, checks): m.is_sparse = is_sparse(mjm) m.has_fluid = mjm.opt.wind.any() or mjm.opt.density > 0 or mjm.opt.viscosity > 0 + m.max_ten_J_rownnz = int(mjm.ten_J_rownnz.max()) if mjm.ntendon else 0 + # body ids grouped by tree level (depth-based traversal) bodies, body_depth = {}, np.zeros(mjm.nbody, dtype=int) - 1 for i in range(mjm.nbody): @@ -364,6 +363,46 @@ def geom_trid_index(i, j): ) ) + # check for unsupported margin + multicontact / box-box CCD combinations + use_multiccd = mjm.opt.enableflags & types.EnableBit.MULTICCD + nativeccd_disabled = mjm.opt.disableflags & types.DisableBit.NATIVECCD + BOX = int(mujoco.mjtGeom.mjGEOM_BOX) + MESH = int(mujoco.mjtGeom.mjGEOM_MESH) + + has_boxbox = m.geom_pair_type_count[geom_trid_index(BOX, BOX)] > 0 + has_multiccd_pairs = has_boxbox or ( + use_multiccd + and (m.geom_pair_type_count[geom_trid_index(BOX, MESH)] > 0 or m.geom_pair_type_count[geom_trid_index(MESH, MESH)] > 0) + ) + + if has_multiccd_pairs: + + def _check_margin(name, t1, t2, margin): + if use_multiccd: + raise NotImplementedError( + f"{name} has non-zero margin ({margin}) with MULTICCD enabled. Set margin to 0 or disable MULTICCD." + ) + if t1 == BOX and t2 == BOX and not nativeccd_disabled: + raise NotImplementedError( + f"{name} has non-zero margin ({margin}) with NATIVECCD enabled. Set margin to 0 or disable NATIVECCD." + ) + + geom_name = lambda g: mujoco.mj_id2name(mjm, mujoco.mjtObj.mjOBJ_GEOM, g) or str(g) + + for idx in np.nonzero(nxn_include & (nxn_pairid_contact == -1))[0]: + g1, g2 = int(geom1[idx]), int(geom2[idx]) + t1, t2 = int(mjm.geom_type[g1]), int(mjm.geom_type[g2]) + m1, m2 = float(mjm.geom_margin[g1]), float(mjm.geom_margin[g2]) + if (m1 or m2) and t1 in (BOX, MESH) and t2 in (BOX, MESH): + _check_margin(f"geom pair ({geom_name(g1)}, {geom_name(g2)})", t1, t2, (m1, m2)) + + for pid in range(mjm.npair): + g1, g2 = int(mjm.pair_geom1[pid]), int(mjm.pair_geom2[pid]) + t1, t2 = int(mjm.geom_type[g1]), int(mjm.geom_type[g2]) + pm = float(mjm.pair_margin[pid]) + if pm and t1 in (BOX, MESH) and t2 in (BOX, MESH): + _check_margin(f"pair {pid} ({geom_name(g1)}, {geom_name(g2)})", t1, t2, pm) + m.nmaxpolygon = np.append(mjm.mesh_polyvertnum, 0).max() m.nmaxmeshdeg = np.append(mjm.mesh_polymapnum, 0).max() @@ -390,9 +429,11 @@ def geom_trid_index(i, j): current = [] else: current.append(v) - # Pad with zeros if less than 3 - attr_values += [0.0] * (3 - len(attr_values)) - m.plugin_attr.append(attr_values[:3]) + if len(attr_values) > types._NPLUGINATTR: + raise ValueError(f"Plugin has {len(attr_values)} attributes, which exceeds the maximum of {types._NPLUGINATTR}. ") + # pad with zeros to _NPLUGINATTR + attr_values += [0.0] * (types._NPLUGINATTR - len(attr_values)) + m.plugin_attr.append(attr_values[: types._NPLUGINATTR]) # equality constraint addresses m.eq_connect_adr = np.nonzero(mjm.eq_type == types.EqType.CONNECT)[0] @@ -542,6 +583,15 @@ def geom_trid_index(i, j): Madr_ki -= 1 m.qLD_updates = tuple(wp.array(qLD_updates[i], dtype=wp.vec3i) for i in sorted(qLD_updates)) + # Build concatenated updates for fused kernel + all_updates_flat = [] + level_offsets = [0] + for level in sorted(qLD_updates): + all_updates_flat.extend(qLD_updates[level]) + level_offsets.append(len(all_updates_flat)) + m.qLD_all_updates = all_updates_flat if all_updates_flat else [(0, 0, 0)] + m.qLD_level_offsets = level_offsets + # indices for sparse qM_fullm (used in solver) m.qM_fullm_i, m.qM_fullm_j = [], [] for i in range(mjm.nv): @@ -631,9 +681,168 @@ def _default_njmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) -> return int(valid_sizes[np.searchsorted(valid_sizes, njmax)]) -def _resolve_batch_size( - na: int | None, n: int | None, nworld: int, default: int -) -> int: +def _body_pair_nnz(mjm: mujoco.MjModel, body1: int, body2: int) -> int: + """Returns the number of unique DOFs in the kinematic tree union of two bodies.""" + body1 = mjm.body_weldid[body1] + body2 = mjm.body_weldid[body2] + da1 = mjm.body_dofadr[body1] + mjm.body_dofnum[body1] - 1 + da2 = mjm.body_dofadr[body2] + mjm.body_dofnum[body2] - 1 + nnz = 0 + while da1 >= 0 or da2 >= 0: + da = max(da1, da2) + if da1 == da: + da1 = mjm.dof_parentid[da1] + if da2 == da: + da2 = mjm.dof_parentid[da2] + nnz += 1 + return nnz + + +def _default_njmax_nnz(mjm: mujoco.MjModel, nconmax: int, njmax: int) -> int: + """Returns a heuristic estimate for the number of non-zeros in the sparse constraint Jacobian. + + Assumes all equality, friction, and limit constraints are active and computes + their non-zeros. For contacts, assumes njmax contact rows at the maximum + body-pair non-zeros from all enabled collision pairs. + + Args: + mjm: The model containing kinematic and dynamic information (host). + nconmax: Maximum number of contacts per world. + njmax: Maximum number of constraint rows per world. + + Returns: + Estimated number of non-zeros in the constraint Jacobian. + """ + total_nnz = 0 + + def _eq_bodies(i): + """Returns body pair for equality constraint i.""" + obj1id, obj2id = mjm.eq_obj1id[i], mjm.eq_obj2id[i] + if mjm.eq_objtype[i] == mujoco.mjtObj.mjOBJ_SITE: + return mjm.site_bodyid[obj1id], mjm.site_bodyid[obj2id] + return obj1id, obj2id + + # equality constraints (assume all active) + for i in range(mjm.neq): + eq_type = mjm.eq_type[i] + + if eq_type == mujoco.mjtEq.mjEQ_CONNECT: + total_nnz += 3 * _body_pair_nnz(mjm, *_eq_bodies(i)) + + elif eq_type == mujoco.mjtEq.mjEQ_WELD: + total_nnz += 6 * _body_pair_nnz(mjm, *_eq_bodies(i)) + + elif eq_type == mujoco.mjtEq.mjEQ_JOINT: + total_nnz += 2 if mjm.eq_obj2id[i] >= 0 else 1 + + elif eq_type == mujoco.mjtEq.mjEQ_TENDON: + obj1id = mjm.eq_obj1id[i] + obj2id = mjm.eq_obj2id[i] + rownnz1 = mjm.ten_J_rownnz[obj1id] if obj1id < mjm.ntendon else 0 + if obj2id >= 0 and obj2id < mjm.ntendon: + rowadr1 = mjm.ten_J_rowadr[obj1id] + rowadr2 = mjm.ten_J_rowadr[obj2id] + rownnz2 = mjm.ten_J_rownnz[obj2id] + cols = set() + for j in range(rownnz1): + cols.add(mjm.ten_J_colind[rowadr1 + j]) + for j in range(rownnz2): + cols.add(mjm.ten_J_colind[rowadr2 + j]) + total_nnz += len(cols) + else: + total_nnz += rownnz1 + + elif eq_type == mujoco.mjtEq.mjEQ_FLEX: + obj1id = mjm.eq_obj1id[i] + if obj1id < mjm.nflex: + edge_start = mjm.flex_edgeadr[obj1id] + edge_count = mjm.flex_edgenum[obj1id] + for e in range(edge_count): + total_nnz += mjm.flexedge_J_rownnz[edge_start + e] + + # friction constraints + total_nnz += (mjm.dof_frictionloss > 0).sum() + for i in range(mjm.ntendon): + if mjm.tendon_frictionloss[i] > 0: + total_nnz += mjm.ten_J_rownnz[i] + + # limit constraints (assume all active) + for i in range(mjm.njnt): + if mjm.jnt_limited[i]: + jnt_type = mjm.jnt_type[i] + if jnt_type == mujoco.mjtJoint.mjJNT_BALL: + total_nnz += 3 + elif jnt_type in (mujoco.mjtJoint.mjJNT_SLIDE, mujoco.mjtJoint.mjJNT_HINGE): + total_nnz += 1 + for i in range(mjm.ntendon): + if mjm.tendon_limited[i]: + total_nnz += mjm.ten_J_rownnz[i] + + # contact constraints: njmax rows at max body-pair non-zeros + max_contact_nnz = 0 + + # contact pairs + for i in range(mjm.npair): + g1, g2 = mjm.pair_geom1[i], mjm.pair_geom2[i] + b1, b2 = mjm.geom_bodyid[g1], mjm.geom_bodyid[g2] + max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, b1, b2)) + + # filter geom-geom pairs (unique body pairs, filtered) + body_pair_seen = set() + for i in range(mjm.ngeom): + bi = mjm.geom_bodyid[i] + cti, cai = mjm.geom_contype[i], mjm.geom_conaffinity[i] + for j in range(i + 1, mjm.ngeom): + bj = mjm.geom_bodyid[j] + if bi == bj: + continue + if mjm.body_weldid[bi] == 0 and mjm.body_weldid[bj] == 0: + continue + bp = (min(bi, bj), max(bi, bj)) + if bp in body_pair_seen: + continue + ctj, caj = mjm.geom_contype[j], mjm.geom_conaffinity[j] + if not ((cti & caj) or (ctj & cai)): + continue + body_pair_seen.add(bp) + max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, bi, bj)) + + # flex vertex contacts + for fi in range(mjm.nflex): + fct = mjm.flex_contype[fi] + fca = mjm.flex_conaffinity[fi] + + vert_start = mjm.flex_vertadr[fi] + vert_count = mjm.flex_vertnum[fi] + flex_bodies = {mjm.flex_vertbodyid[vert_start + v] for v in range(vert_count)} + + geom_bodies = set() + for g in range(mjm.ngeom): + ct, ca = mjm.geom_contype[g], mjm.geom_conaffinity[g] + if (fct & ca) or (ct & fca): + geom_bodies.add(mjm.geom_bodyid[g]) + + for fb in flex_bodies: + for gb in geom_bodies: + if fb != gb: + max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, fb, gb)) + + # flex self-collision + if mjm.flex_selfcollide[fi]: + flex_body_list = sorted(flex_bodies) + for idx1 in range(len(flex_body_list)): + for idx2 in range(idx1 + 1, len(flex_body_list)): + max_contact_nnz = max( + max_contact_nnz, + _body_pair_nnz(mjm, flex_body_list[idx1], flex_body_list[idx2]), + ) + + total_nnz += njmax * max_contact_nnz + + return int(min(max(total_nnz, 1), njmax * mjm.nv)) + + +def _resolve_batch_size(na: int | None, n: int | None, nworld: int, default: int) -> int: if na is not None: return na if n is not None: @@ -647,6 +856,7 @@ def make_data( nconmax: Optional[int] = None, nccdmax: Optional[int] = None, njmax: Optional[int] = None, + njmax_nnz: Optional[int] = None, naconmax: Optional[int] = None, naccdmax: Optional[int] = None, ) -> types.Data: @@ -660,6 +870,7 @@ def make_data( nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax. njmax: Number of constraints to allocate per world. Constraint arrays are batched by world: no world may have more than njmax constraints. + njmax_nnz: Number of non-zeros in constraint Jacobian (sparse). Defaults to njmax * nv. naconmax: Number of contacts to allocate for all worlds. Overrides nconmax. naccdmax: Maximum number of CCD contacts. Defaults to naconmax. @@ -709,20 +920,32 @@ def make_data( sizes["naconmax"] = naconmax sizes["njmax"] = njmax + if njmax_nnz is None: + if is_sparse(mjm): + njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax) + else: + njmax_nnz = njmax * mjm.nv + contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)}) + contact.efc_address = wp.array(np.full((naconmax, sizes["nmaxpyramid"]), -1, dtype=int), dtype=int) efc = types.Constraint(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Constraint)}) - if SPARSE_CONSTRAINT_JACOBIAN: + if is_sparse(mjm): efc.J_rownnz = wp.zeros((nworld, njmax), dtype=int) efc.J_rowadr = wp.zeros((nworld, njmax), dtype=int) - efc.J_colind = wp.zeros((nworld, 1, njmax * mjm.nv), dtype=int) - efc.J = wp.zeros((nworld, 1, njmax * mjm.nv), dtype=float) + efc.J_colind = wp.zeros((nworld, 1, njmax_nnz), dtype=int) + efc.J = wp.zeros((nworld, 1, njmax_nnz), dtype=float) else: efc.J_rownnz = wp.zeros((nworld, 0), dtype=int) efc.J_rowadr = wp.zeros((nworld, 0), dtype=int) efc.J_colind = wp.zeros((nworld, 0, 0), dtype=int) efc.J = wp.zeros((nworld, sizes["njmax_pad"], sizes["nv_pad"]), dtype=float) + contact_kwargs = {} + for f in dataclasses.fields(types.Contact): + contact_kwargs[f.name] = _create_array(None, f.type, sizes) + contact = types.Contact(**contact_kwargs) + # world body and static geom (attached to the world) poses are precomputed # this speeds up scenes with many static geoms (e.g. terrains) # TODO(team): remove this when we introduce dof islands + sleeping @@ -734,65 +957,34 @@ def make_data( mocap_id = mjm.body_mocapid[mocap_body] d_kwargs = { - "qpos": wp.array( - np.tile(mjm.qpos0, nworld), shape=(nworld, mjm.nq), dtype=float - ), - "contact": contact, - "efc": efc, - "nworld": nworld, - "naconmax": naconmax, - "naccdmax": naccdmax, - "njmax": njmax, - "njmax_pad": sizes["njmax_pad"], - "qM": None, - "qLD": None, - # world body - "xquat": wp.array( - np.tile(mjd.xquat, (nworld, 1)), - shape=(nworld, mjm.nbody), - dtype=wp.quat, - ), - "xmat": wp.array( - np.tile(mjd.xmat, (nworld, 1)), - shape=(nworld, mjm.nbody), - dtype=wp.mat33, - ), - "ximat": wp.array( - np.tile(mjd.ximat, (nworld, 1)), - shape=(nworld, mjm.nbody), - dtype=wp.mat33, - ), - # static geoms - "geom_xpos": wp.array( - np.tile(mjd.geom_xpos, (nworld, 1)), - shape=(nworld, mjm.ngeom), - dtype=wp.vec3, - ), - "geom_xmat": wp.array( - np.tile(mjd.geom_xmat, (nworld, 1)), - shape=(nworld, mjm.ngeom), - dtype=wp.mat33, - ), - # mocap - "mocap_pos": wp.array( - np.tile(mjm.body_pos[mocap_body[mocap_id]], (nworld, 1)), - shape=(nworld, mjm.nmocap), - dtype=wp.vec3, - ), - "mocap_quat": wp.array( - np.tile(mjm.body_quat[mocap_body[mocap_id]], (nworld, 1)), - shape=(nworld, mjm.nmocap), - dtype=wp.quat, - ), - # equality constraints - "eq_active": wp.array( - np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), - shape=(nworld, mjm.neq), - dtype=bool, - ), - # island arrays - "nisland": None, - "tree_island": None, + "qpos": wp.array(np.tile(mjm.qpos0, nworld), shape=(nworld, mjm.nq), dtype=float), + "contact": contact, + "efc": efc, + "nworld": nworld, + "naconmax": naconmax, + "naccdmax": naccdmax, + "njmax": njmax, + "njmax_pad": sizes["njmax_pad"], + "njmax_nnz": njmax_nnz, + "qM": None, + "qLD": None, + # world body + "xquat": wp.array(np.tile(mjd.xquat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.quat), + "xmat": wp.array(np.tile(mjd.xmat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.mat33), + "ximat": wp.array(np.tile(mjd.ximat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.mat33), + # static geoms + "geom_xpos": wp.array(np.tile(mjd.geom_xpos, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.vec3), + "geom_xmat": wp.array(np.tile(mjd.geom_xmat, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.mat33), + # mocap + "mocap_pos": wp.array(np.tile(mjm.body_pos[mocap_body[mocap_id]], (nworld, 1)), shape=(nworld, mjm.nmocap), dtype=wp.vec3), + "mocap_quat": wp.array( + np.tile(mjm.body_quat[mocap_body[mocap_id]], (nworld, 1)), shape=(nworld, mjm.nmocap), dtype=wp.quat + ), + # equality constraints + "eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool), + # island arrays + "nisland": None, + "tree_island": None, } for f in dataclasses.fields(types.Data): if f.name in d_kwargs: @@ -822,6 +1014,7 @@ def put_data( nconmax: Optional[int] = None, nccdmax: Optional[int] = None, njmax: Optional[int] = None, + njmax_nnz: Optional[int] = None, naconmax: Optional[int] = None, naccdmax: Optional[int] = None, ) -> types.Data: @@ -836,6 +1029,7 @@ def put_data( nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax. njmax: Number of constraints to allocate per world. Constraint arrays are batched by world: no world may have more than njmax constraints. + njmax_nnz: Number of non-zeros in constraint Jacobian (sparse). Defaults to njmax * nv. naconmax: Number of contacts to allocate for all worlds. Overrides nconmax. naccdmax: Maximum number of CCD contacts. Defaults to naconmax. @@ -898,6 +1092,12 @@ def put_data( sizes["naconmax"] = naconmax sizes["njmax"] = njmax + if njmax_nnz is None: + if is_sparse(mjm): + njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax) + else: + njmax_nnz = njmax * mjm.nv + # ensure static geom positions are computed # TODO: remove once MjData creation semantics are fixed mujoco.mj_kinematics(mjm, mjd) @@ -915,7 +1115,7 @@ def put_data( contact = types.Contact(**contact_kwargs) - contact.efc_address = np.zeros((naconmax, sizes["nmaxpyramid"]), dtype=int) + contact.efc_address = np.full((naconmax, sizes["nmaxpyramid"]), -1, dtype=int) for i in range(mjd.ncon): efc_address = mjd.contact.efc_address[i] if efc_address == -1: @@ -945,43 +1145,28 @@ def put_data( efc = types.Constraint(**efc_kwargs) - if SPARSE_CONSTRAINT_JACOBIAN: - # TODO(team): process efc_J sparsity structure for nv row shift - efc.J_rownnz = wp.array( - np.full((nworld, njmax), mjm.nv, dtype=int), dtype=int - ) - efc.J_rowadr = wp.array( - np.tile( - np.arange(0, njmax * mjm.nv, mjm.nv) - if mjm.nv - else np.zeros(njmax, dtype=int), - (nworld, 1), - ), - dtype=int, - ) - efc.J_colind = wp.array( - np.tile(np.arange(mjm.nv), (nworld, njmax)).reshape((nworld, 1, -1)), - dtype=int, - ) - - mj_efc_J = np.zeros((mjd.nefc, mjm.nv)) + if is_sparse(mjm): + J_rownnz = np.zeros(njmax, dtype=np.int32) + J_rowadr = np.zeros(njmax, dtype=np.int32) + J_colind = np.zeros(njmax_nnz, dtype=np.int32) + J = np.zeros(njmax_nnz, dtype=np.float64) if mjd.nefc: if mujoco.mj_isSparse(mjm): - mujoco.mju_sparse2dense( - mj_efc_J, - mjd.efc_J, - mjd.efc_J_rownnz, - mjd.efc_J_rowadr, - mjd.efc_J_colind, - ) + J_rownnz[: mjd.nefc] = mjd.efc_J_rownnz[: mjd.nefc] + J_rowadr[: mjd.nefc] = mjd.efc_J_rowadr[: mjd.nefc] + nnz = int(mjd.efc_J_rownnz[: mjd.nefc].sum()) + J_colind[:nnz] = mjd.efc_J_colind[:nnz] + J[:nnz] = mjd.efc_J[:nnz] else: - mj_efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv)) - efc_J = np.zeros((njmax, mjm.nv), dtype=float) - efc_J[: mjd.nefc, : mjm.nv] = mj_efc_J - efc.J = wp.array( - np.tile(efc_J.reshape(-1), (nworld, 1, 1)).reshape((nworld, 1, -1)), - dtype=float, - ) + dense_J = mjd.efc_J.reshape((-1, mjm.nv))[: mjd.nefc] + mujoco.mju_dense2sparse( + J[: mjd.nefc * mjm.nv], dense_J, J_rownnz[: mjd.nefc], J_rowadr[: mjd.nefc], J_colind[: mjd.nefc * mjm.nv] + ) + + efc.J_rownnz = wp.array(np.tile(J_rownnz, (nworld, 1)), dtype=int) + efc.J_rowadr = wp.array(np.tile(J_rowadr, (nworld, 1)), dtype=int) + efc.J_colind = wp.array(np.tile(J_colind, (nworld, 1)).reshape((nworld, 1, -1)), dtype=int) + efc.J = wp.array(np.tile(J, (nworld, 1)).reshape((nworld, 1, -1)), dtype=float) else: efc.J_rownnz = wp.zeros((nworld, 0), dtype=int) efc.J_rowadr = wp.zeros((nworld, 0), dtype=int) @@ -990,13 +1175,7 @@ def put_data( mj_efc_J = np.zeros((mjd.nefc, mjm.nv)) if mjd.nefc: if mujoco.mj_isSparse(mjm): - mujoco.mju_sparse2dense( - mj_efc_J, - mjd.efc_J, - mjd.efc_J_rownnz, - mjd.efc_J_rowadr, - mjd.efc_J_colind, - ) + mujoco.mju_sparse2dense(mj_efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind) else: mj_efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv)) efc_J = np.zeros((nworld, sizes["njmax_pad"], sizes["nv_pad"]), dtype=float) @@ -1005,22 +1184,22 @@ def put_data( # create data d_kwargs = { - "contact": contact, - "efc": efc, - "nworld": nworld, - "naconmax": naconmax, - "naccdmax": naccdmax, - "njmax": njmax, - "njmax_pad": sizes["njmax_pad"], - # fields set after initialization: - "solver_niter": None, - "qM": None, - "qLD": None, - "ten_J": None, - "nacon": None, - # island arrays - "nisland": None, - "tree_island": None, + "contact": contact, + "efc": efc, + "nworld": nworld, + "naconmax": naconmax, + "naccdmax": naccdmax, + "njmax": njmax, + "njmax_pad": sizes["njmax_pad"], + "njmax_nnz": njmax_nnz, + # fields set after initialization: + "solver_niter": None, + "qM": None, + "qLD": None, + "nacon": None, + # island arrays + "nisland": None, + "tree_island": None, } for f in dataclasses.fields(types.Data): if f.name in d_kwargs: @@ -1050,29 +1229,6 @@ def put_data( d.nisland = wp.array(np.full(nworld, mjd.nisland), dtype=int) d.tree_island = wp.array(np.tile(mjd.tree_island, (nworld, 1)), dtype=int) - ten_J = np.zeros((mjm.ntendon, mjm.nv)) - if mujoco.mj_isSparse(mjm) or check_version("mujoco>=3.5.1.dev872479828"): - if mjm.ntendon: - if check_version("mujoco>=3.5.1.dev875093374"): - mujoco.mju_sparse2dense( - ten_J, - mjd.ten_J.reshape(-1), - mjm.ten_J_rownnz, - mjm.ten_J_rowadr, - mjm.ten_J_colind.reshape(-1), - ) - else: - mujoco.mju_sparse2dense( - ten_J, - mjd.ten_J.reshape(-1), - mjd.ten_J_rownnz, - mjd.ten_J_rowadr, - mjd.ten_J_colind.reshape(-1), - ) - else: - ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv)) - d.ten_J = wp.array(np.full((nworld, mjm.ntendon, mjm.nv), ten_J), dtype=float) - d.nacon = wp.array([mjd.ncon * nworld], dtype=int) return d @@ -1233,14 +1389,14 @@ def get_data_into( mujoco.mj_factorM(mjm, result) if nefc > 0: - if SPARSE_CONSTRAINT_JACOBIAN: + if is_sparse(mjm): efc_J = np.zeros((nefc, mjm.nv)) mujoco.mju_sparse2dense( - efc_J, - d.efc.J.numpy()[world_id, 0], - d.efc.J_rownnz.numpy()[world_id, :nefc], - d.efc.J_rowadr.numpy()[world_id, :nefc], - d.efc.J_colind.numpy()[world_id, 0], + efc_J, + d.efc.J.numpy()[world_id, 0], + d.efc.J_rownnz.numpy()[world_id, :nefc], + d.efc.J_rowadr.numpy()[world_id, :nefc], + d.efc.J_colind.numpy()[world_id, 0], ) else: efc_J = d.efc.J.numpy()[world_id, :nefc, : mjm.nv] @@ -1248,11 +1404,11 @@ def get_data_into( # write to mujoco result (format depends on mj_isSparse) if mujoco.mj_isSparse(mjm): mujoco.mju_dense2sparse( - result.efc_J, - efc_J[efc_idx], - result.efc_J_rownnz, - result.efc_J_rowadr, - result.efc_J_colind, + result.efc_J, + efc_J[efc_idx], + result.efc_J_rownnz, + result.efc_J_rowadr, + result.efc_J_colind, ) else: result.efc_J[: nefc * mjm.nv] = efc_J[efc_idx].flatten() @@ -1276,24 +1432,7 @@ def get_data_into( # tendon result.ten_length[:] = d.ten_length.numpy()[world_id] - if check_version("mujoco>=3.5.1.dev869712136"): - ten_J = d.ten_J.numpy()[world_id] - if check_version("mujoco>=3.5.1.dev875093374"): - ten_J_rownnz = mjm.ten_J_rownnz - ten_J_rowadr = mjm.ten_J_rowadr - ten_J_colind = mjm.ten_J_colind.reshape(-1) - else: - ten_J_rownnz = result.ten_J_rownnz - ten_J_rowadr = result.ten_J_rowadr - ten_J_colind = result.ten_J_colind.reshape(-1) - mujoco.mju_dense2sparse( - result.ten_J, - ten_J, - ten_J_rownnz, - ten_J_rowadr, - ten_J_colind, - ) - else: + if mjm.ntendon > 0: result.ten_J[:] = d.ten_J.numpy()[world_id] result.ten_wrapadr[:] = d.ten_wrapadr.numpy()[world_id] result.ten_wrapnum[:] = d.ten_wrapnum.numpy()[world_id] @@ -1428,12 +1567,8 @@ def reset_mocap( mocapid = body_mocapid[bodyid] if mocapid >= 0: - mocap_pos_out[worldid, mocapid] = body_pos[ - worldid % body_pos.shape[0], bodyid - ] - mocap_quat_out[worldid, mocapid] = body_quat[ - worldid % body_quat.shape[0], bodyid - ] + mocap_pos_out[worldid, mocapid] = body_pos[worldid % body_pos.shape[0], bodyid] + mocap_quat_out[worldid, mocapid] = body_quat[worldid % body_quat.shape[0], bodyid] @wp.kernel(module="unique", enable_backward=False) def reset_contact( @@ -1453,6 +1588,8 @@ def reset_contact( contact_solimp_out: wp.array(dtype=types.vec5), contact_dim_out: wp.array(dtype=int), contact_geom_out: wp.array(dtype=wp.vec2i), + contact_flex_out: wp.array(dtype=wp.vec2i), + contact_vert_out: wp.array(dtype=wp.vec2i), contact_efc_address_out: wp.array2d(dtype=int), contact_worldid_out: wp.array(dtype=int), contact_type_out: wp.array(dtype=int), @@ -1479,8 +1616,10 @@ def reset_contact( contact_solimp_out[conid] = types.vec5(0.0, 0.0, 0.0, 0.0, 0.0) contact_dim_out[conid] = 0 contact_geom_out[conid] = wp.vec2i(0, 0) + contact_flex_out[conid] = wp.vec2i(0, 0) + contact_vert_out[conid] = wp.vec2i(0, 0) for i in range(nefcaddress): - contact_efc_address_out[conid, i] = 0 + contact_efc_address_out[conid, i] = -1 contact_worldid_out[conid] = 0 contact_type_out[conid] = 0 contact_geomcollisionid_out[conid] = 0 @@ -1519,6 +1658,8 @@ def reset_contact( d.contact.solimp, d.contact.dim, d.contact.geom, + d.contact.flex, + d.contact.vert, d.contact.efc_address, d.contact.worldid, d.contact.type, @@ -1812,28 +1953,45 @@ def _finalize_body_invweight0( @wp.kernel def _copy_tendon_jacobian( tenid_target: int, - ten_J_in: wp.array3d(dtype=float), + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + ten_J_in: wp.array2d(dtype=float), ten_J_vec_out: wp.array2d(dtype=float), ): worldid = wp.tid() nv = ten_J_in.shape[2] - for i in range(nv): - ten_J_vec_out[worldid, i] = ten_J_in[worldid, tenid_target, i] + rownnz = ten_J_rownnz[tenid_target] + rowadr = ten_J_rowadr[tenid_target] + for i in range(rownnz): + colind = ten_J_colind[rowadr + i] + ten_J_vec_out[worldid, colind] = ten_J_in[worldid, rowadr + i] @wp.kernel def _compute_tendon_dot_product( + # Model: + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + # In: tenid_target: int, - nv: int, - ten_J_in: wp.array3d(dtype=float), + ten_J_in: wp.array2d(dtype=float), result_vec_in: wp.array2d(dtype=float), + # Out: tendon_invweight0_out: wp.array2d(dtype=float), ): worldid = wp.tid() tendon_invweight0_id = worldid % tendon_invweight0_out.shape[0] dot_prod = float(0.0) - for i in range(nv): - dot_prod += ten_J_in[worldid, tenid_target, i] * result_vec_in[worldid, i] + + rownnz = ten_J_rownnz[tenid_target] + rowadr = ten_J_rowadr[tenid_target] + for i in range(rownnz): + sparseid = rowadr + i + colind = ten_J_colind[sparseid] + dot_prod += ten_J_in[worldid, sparseid] * result_vec_in[worldid, colind] + tendon_invweight0_out[tendon_invweight0_id, tenid_target] = dot_prod @@ -1891,12 +2049,12 @@ def _compute_light_pos0( @wp.kernel def _copy_actuator_moment( - actid_target: int, - moment_rownnz_in: wp.array2d(dtype=int), - moment_rowadr_in: wp.array2d(dtype=int), - moment_colind_in: wp.array2d(dtype=int), - actuator_moment_in: wp.array2d(dtype=float), - act_moment_vec_out: wp.array2d(dtype=float), + actid_target: int, + moment_rownnz_in: wp.array2d(dtype=int), + moment_rowadr_in: wp.array2d(dtype=int), + moment_colind_in: wp.array2d(dtype=int), + actuator_moment_in: wp.array2d(dtype=float), + act_moment_vec_out: wp.array2d(dtype=float), ): worldid = wp.tid() nv = act_moment_vec_out.shape[1] @@ -1912,10 +2070,10 @@ def _copy_actuator_moment( @wp.kernel def _compute_actuator_acc0( - actid_target: int, - nv: int, - result_vec_in: wp.array2d(dtype=float), - actuator_acc0_out: wp.array2d(dtype=float), + actid_target: int, + nv: int, + result_vec_in: wp.array2d(dtype=float), + actuator_acc0_out: wp.array2d(dtype=float), ): worldid = wp.tid() norm_sq = float(0.0) @@ -1926,11 +2084,11 @@ def _compute_actuator_acc0( @wp.kernel def _compute_dof_M0( - dof_bodyid: wp.array(dtype=int), - dof_armature: wp.array2d(dtype=float), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - crb_in: wp.array2d(dtype=vec10), - dof_M0_out: wp.array2d(dtype=float), + dof_bodyid: wp.array(dtype=int), + dof_armature: wp.array2d(dtype=float), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + crb_in: wp.array2d(dtype=vec10), + dof_M0_out: wp.array2d(dtype=float), ): worldid, dofid = wp.tid() bodyid = dof_bodyid[dofid] @@ -1941,15 +2099,15 @@ def _compute_dof_M0( @wp.kernel def _resolve_dampratio( - actuator_biastype: wp.array(dtype=int), - actuator_gainprm: wp.array2d(dtype=types.vec10f), - moment_rownnz_in: wp.array2d(dtype=int), - moment_rowadr_in: wp.array2d(dtype=int), - moment_colind_in: wp.array2d(dtype=int), - actuator_moment_in: wp.array2d(dtype=float), - dof_M0_in: wp.array2d(dtype=float), - nv: int, - actuator_biasprm: wp.array2d(dtype=types.vec10f), + actuator_biastype: wp.array(dtype=int), + actuator_gainprm: wp.array2d(dtype=types.vec10f), + moment_rownnz_in: wp.array2d(dtype=int), + moment_rowadr_in: wp.array2d(dtype=int), + moment_colind_in: wp.array2d(dtype=int), + actuator_moment_in: wp.array2d(dtype=float), + dof_M0_in: wp.array2d(dtype=float), + nv: int, + actuator_biasprm: wp.array2d(dtype=types.vec10f), ): worldid, actid = wp.tid() biastype = actuator_biastype[actid] @@ -1992,15 +2150,15 @@ def _resolve_dampratio( @wp.kernel def _set_length_range( - actuator_trntype: wp.array(dtype=int), - actuator_trnid: wp.array(dtype=wp.vec2i), - actuator_gear: wp.array2d(dtype=wp.spatial_vector), - jnt_limited: wp.array(dtype=int), - jnt_range: wp.array2d(dtype=wp.vec2), - tendon_limited: wp.array(dtype=int), - tendon_range: wp.array2d(dtype=wp.vec2), - ntendon: int, - actuator_lengthrange_out: wp.array2d(dtype=wp.vec2), + actuator_trntype: wp.array(dtype=int), + actuator_trnid: wp.array(dtype=wp.vec2i), + actuator_gear: wp.array2d(dtype=wp.spatial_vector), + jnt_limited: wp.array(dtype=int), + jnt_range: wp.array2d(dtype=wp.vec2), + tendon_limited: wp.array(dtype=int), + tendon_range: wp.array2d(dtype=wp.vec2), + ntendon: int, + actuator_lengthrange_out: wp.array2d(dtype=wp.vec2), ): worldid, actid = wp.tid() trntype = actuator_trntype[actid] @@ -2167,16 +2325,22 @@ def set_const_0(m: types.Model, d: types.Data): # tendon_invweight0[t] = J_t * inv(M) * J_t' if m.ntendon > 0: - ten_J_vec = wp.zeros((d.nworld, m.nv), dtype=float) - ten_result_vec = wp.zeros((d.nworld, m.nv), dtype=float) + ten_J_vec = wp.empty((d.nworld, m.nv), dtype=float) + ten_result_vec = wp.empty((d.nworld, m.nv), dtype=float) for tenid in range(m.ntendon): - wp.launch(_copy_tendon_jacobian, dim=d.nworld, inputs=[tenid, d.ten_J], outputs=[ten_J_vec]) + ten_J_vec.zero_() + wp.launch( + _copy_tendon_jacobian, + dim=d.nworld, + inputs=[tenid, m.ten_J_rownnz, m.ten_J_rowadr, m.ten_J_colind, d.ten_J], + outputs=[ten_J_vec], + ) smooth.solve_m(m, d, ten_result_vec, ten_J_vec) wp.launch( _compute_tendon_dot_product, dim=d.nworld, - inputs=[tenid, m.nv, d.ten_J, ten_result_vec], + inputs=[m.ten_J_rownnz, m.ten_J_rowadr, m.ten_J_colind, tenid, d.ten_J, ten_result_vec], outputs=[m.tendon_invweight0], ) @@ -2201,16 +2365,10 @@ def set_const_0(m: types.Model, d: types.Data): for actid in range(m.nu): wp.launch( - _copy_actuator_moment, - dim=d.nworld, - inputs=[ - actid, - d.moment_rownnz, - d.moment_rowadr, - d.moment_colind, - d.actuator_moment, - ], - outputs=[act_moment_vec], + _copy_actuator_moment, + dim=d.nworld, + inputs=[actid, d.moment_rownnz, d.moment_rowadr, d.moment_colind, d.actuator_moment], + outputs=[act_moment_vec], ) smooth.solve_m(m, d, act_result_vec, act_moment_vec) wp.launch(_compute_actuator_acc0, dim=d.nworld, inputs=[actid, m.nv, act_result_vec], outputs=[m.actuator_acc0]) @@ -2219,25 +2377,25 @@ def set_const_0(m: types.Model, d: types.Data): if m.nu > 0 and m.nv > 0: dof_M0 = wp.zeros((d.nworld, m.nv), dtype=float) wp.launch( - _compute_dof_M0, - dim=(d.nworld, m.nv), - inputs=[m.dof_bodyid, m.dof_armature, d.cdof, d.crb], - outputs=[dof_M0], + _compute_dof_M0, + dim=(d.nworld, m.nv), + inputs=[m.dof_bodyid, m.dof_armature, d.cdof, d.crb], + outputs=[dof_M0], ) wp.launch( - _resolve_dampratio, - dim=(d.nworld, m.nu), - inputs=[ - m.actuator_biastype, - m.actuator_gainprm, - d.moment_rownnz, - d.moment_rowadr, - d.moment_colind, - d.actuator_moment, - dof_M0, - m.nv, - ], - outputs=[m.actuator_biasprm], + _resolve_dampratio, + dim=(d.nworld, m.nu), + inputs=[ + m.actuator_biastype, + m.actuator_gainprm, + d.moment_rownnz, + d.moment_rowadr, + d.moment_colind, + d.actuator_moment, + dof_M0, + m.nv, + ], + outputs=[m.actuator_biasprm], ) wp.copy(d.qpos, qpos_saved) @@ -2255,16 +2413,12 @@ def set_const(m: types.Model, d: types.Data): Field | Notes ---------------------------------|---------------------------------------------- qpos0, qpos_spring | - body_mass, body_inertia, | Mass and inertia are usually scaled - together + body_mass, body_inertia, | Mass and inertia are usually scaled together body_ipos, body_iquat | since inertia is sum(m * r^2). - body_pos, body_quat | Unsafe for static bodies (invalidates - BVH). - body_gravcomp | If changing from 0 to >0 bodies, - required. + body_pos, body_quat | Unsafe for static bodies (invalidates BVH). + body_gravcomp | If changing from 0 to >0 bodies, required. dof_armature | - eq_data | For connect/weld, offsets computed if not - set. + eq_data | For connect/weld, offsets computed if not set. hfield_size | tendon_stiffness, tendon_damping | Only if changing from/to zero. actuator_gainprm, actuator_biasprm | For position actuators with dampratio. @@ -2319,19 +2473,19 @@ def set_length_range(m: types.Model, d: types.Data, index: int = -1): return wp.launch( - _set_length_range, - dim=(d.nworld, m.nu), - inputs=[ - m.actuator_trntype, - m.actuator_trnid, - m.actuator_gear, - m.jnt_limited, - m.jnt_range, - m.tendon_limited, - m.tendon_range, - m.ntendon, - ], - outputs=[m.actuator_lengthrange], + _set_length_range, + dim=(d.nworld, m.nu), + inputs=[ + m.actuator_trntype, + m.actuator_trnid, + m.actuator_gear, + m.jnt_limited, + m.jnt_range, + m.tendon_limited, + m.tendon_range, + m.ntendon, + ], + outputs=[m.actuator_lengthrange], ) @@ -2486,6 +2640,7 @@ def create_render_context( cam_res: list[tuple[int, int]] | tuple[int, int] | None = None, render_rgb: list[bool] | bool | None = None, render_depth: list[bool] | bool | None = None, + render_seg: list[bool] | bool | None = None, use_textures: bool = True, use_shadows: bool = False, enabled_geom_groups: list[int] = [0, 1, 2], @@ -2502,6 +2657,8 @@ def create_render_context( MuJoCo model values. render_rgb: Whether to render RGB images. If None, uses the MuJoCo model values. render_depth: Whether to render depth images. If None, uses the MuJoCo model values. + render_seg: Whether to render segmentation (per-pixel geom IDs). If None, + uses the MuJoCo model values. use_textures: Whether to use textures. use_shadows: Whether to use shadows. enabled_geom_groups: The geom groups to render. @@ -2517,10 +2674,13 @@ def create_render_context( mjd = mujoco.MjData(mjm) mujoco.mj_forward(mjm, mjd) - # TODO(team): remove after mjwarp depends on warp-lang >= 1.12 in pyproject.toml - if use_textures and not hasattr(wp, "Texture2D"): - warnings.warn("Textures require warp >= 1.12. Disabling textures.") - use_textures = False + constructor = "sah" + if check_version("warp>=1.13.0.dev20260325"): + # TODO: The cubql constructor and is_cubql_available exist only in + # recent Warp 1.13+ builds, modify this after warp is updated to 1.13+. + _cubql_avail = getattr(wp, "is_cubql_available", None) + if callable(_cubql_avail) and _cubql_avail(): + constructor = "cubql" # Mesh BVHs nmesh = mjm.nmesh @@ -2534,7 +2694,7 @@ def create_render_context( mesh_bounds_size = [wp.vec3(0.0, 0.0, 0.0) for _ in range(nmesh)] for mid in used_mesh_id: - mesh, half = bvh.build_mesh_bvh(mjm, mid) + mesh, half = bvh.build_mesh_bvh(mjm, mid, constructor=constructor) mesh_registry[mesh.id] = mesh mesh_bvh_id[mid] = mesh.id mesh_bounds_size[mid] = half @@ -2551,7 +2711,7 @@ def create_render_context( hfield_bounds_size = [wp.vec3(0.0, 0.0, 0.0) for _ in range(nhfield)] for hid in used_hfield_id: - hmesh, hhalf = bvh.build_hfield_bvh(mjm, hid) + hmesh, hhalf = bvh.build_hfield_bvh(mjm, hid, constructor=constructor) hfield_registry[hmesh.id] = hmesh hfield_bvh_id[hid] = hmesh.id hfield_bounds_size[hid] = hhalf @@ -2560,65 +2720,33 @@ def create_render_context( hfield_bounds_size_arr = wp.array(hfield_bounds_size, dtype=wp.vec3) # Flex BVHs - flex_bvh_id = wp.uint64(0) - flex_group_root = wp.zeros(nworld, dtype=int) - flex_mesh = None - flex_face_point = None - flex_elemdataadr = None - flex_shell = None - flex_shelldataadr = None - flex_faceadr = None - flex_nface = 0 - flex_radius = None - flex_workadr = None - flex_worknum = None - flex_nwork = 0 - - if mjm.nflex > 0: - ( - fmesh, - face_point, - flex_group_roots, - flex_shell_data, - flex_faceadr_data, - flex_nface, - ) = bvh.build_flex_bvh(mjm, mjd, nworld) - - flex_mesh = fmesh - flex_bvh_id = fmesh.id - flex_face_point = face_point - flex_group_root = flex_group_roots - flex_elemdataadr = wp.array(mjm.flex_elemdataadr, dtype=int) - flex_shell = flex_shell_data - flex_shelldataadr = wp.array(mjm.flex_shelldataadr, dtype=int) - flex_faceadr = wp.array(flex_faceadr_data, dtype=int) - flex_radius = wp.array(mjm.flex_radius, dtype=float) - - # precompute work item layout for unified refit kernel - nflex = mjm.nflex - workadr = np.zeros(nflex, dtype=np.int32) - worknum = np.zeros(nflex, dtype=np.int32) - cumsum = 0 - for f in range(nflex): - workadr[f] = cumsum - if mjm.flex_dim[f] == 2: - worknum[f] = mjm.flex_elemnum[f] + mjm.flex_shellnum[f] - else: - worknum[f] = mjm.flex_shellnum[f] - cumsum += worknum[f] - flex_workadr = wp.array(workadr, dtype=int) - flex_worknum = wp.array(worknum, dtype=int) - flex_nwork = int(cumsum) + nflex = mjm.nflex + flex_registry = {} + + # Scene BVH flex primitives: 1D → one capsule per edge, 2D/3D → one box per flex + flex_geom_flexid = [] + flex_geom_edgeid = [] + flex_bvh_id = np.full(nflex, 0, dtype=wp.uint64) + flex_group_root = np.zeros((nflex, nworld), dtype=int) + + for f in range(nflex): + if mjm.flex_dim[f] == 1: + edge_adr = mjm.flex_edgeadr[f] + flex_geom_flexid.extend([f] * mjm.flex_edgenum[f]) + flex_geom_edgeid.extend([edge_adr + e for e in range(mjm.flex_edgenum[f])]) + flex_group_root[f] = np.zeros(nworld, dtype=int) + else: + flex_geom_flexid.append(f) + flex_geom_edgeid.append(-1) + fmesh, group_root = bvh.build_flex_bvh(mjm, mjd, nworld, f) + flex_registry[f] = fmesh + flex_bvh_id[f] = fmesh.id + flex_group_root[f] = group_root.numpy() textures_registry = [] - # TODO: remove after mjwarp depends on warp-lang >= 1.12 in pyproject.toml - if hasattr(wp, "Texture2D"): - for i in range(mjm.ntex): - textures_registry.append(render_util.create_warp_texture(mjm, i)) - textures = wp.array(textures_registry, dtype=wp.Texture2D) - else: - # Dummy array when texture support isn't available (warp < 1.12) - textures = wp.zeros(1, dtype=int) + for i in range(mjm.ntex): + textures_registry.append(render_util.create_warp_texture(mjm, i)) + textures = wp.array(textures_registry, dtype=wp.Texture2D) # Filter active cameras if cam_active is not None: @@ -2642,31 +2770,31 @@ def create_render_context( cam_res_arr = wp.array(active_cam_res, dtype=wp.vec2i) if render_rgb is None: - render_rgb = [ - mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_RGB - for i in active_cam_indices - ] + render_rgb = [mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_RGB for i in active_cam_indices] elif isinstance(render_rgb, bool): render_rgb = [render_rgb] * ncam if render_depth is None: - render_depth = [ - mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_DEPTH - for i in active_cam_indices - ] + render_depth = [mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_DEPTH for i in active_cam_indices] if isinstance(render_depth, bool): render_depth = [render_depth] * ncam - assert len(render_rgb) == ncam and len(render_depth) == ncam, ( - "render_rgb and render_depth must be a bool or a list of bools with" - f" length {ncam}" + if render_seg is None: + render_seg = [mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_SEG for i in active_cam_indices] + elif isinstance(render_seg, bool): + render_seg = [render_seg] * ncam + + assert len(render_rgb) == ncam and len(render_depth) == ncam and len(render_seg) == ncam, ( + f"render_rgb, render_depth, and render_seg must be a bool or a list of bools with length {ncam}" ) rgb_adr = -1 * np.ones(ncam, dtype=int) depth_adr = -1 * np.ones(ncam, dtype=int) + seg_adr = -1 * np.ones(ncam, dtype=int) cam_res_np = cam_res_arr.numpy() ri = 0 di = 0 + si = 0 total = 0 for idx in range(ncam): @@ -2676,6 +2804,9 @@ def create_render_context( if render_depth[idx]: depth_adr[idx] = di di += cam_res_np[idx][0] * cam_res_np[idx][1] + if render_seg[idx]: + seg_adr[idx] = si + si += cam_res_np[idx][0] * cam_res_np[idx][1] total += cam_res_np[idx][0] * cam_res_np[idx][1] @@ -2729,26 +2860,20 @@ def create_render_context( hfield_registry=hfield_registry, hfield_bvh_id=hfield_bvh_id_arr, hfield_bounds_size=hfield_bounds_size_arr, - flex_mesh=flex_mesh, + flex_mesh_registry=flex_registry, flex_rgba=wp.array(mjm.flex_rgba, dtype=wp.vec4), - flex_bvh_id=flex_bvh_id, - flex_face_point=flex_face_point, - flex_faceadr=flex_faceadr, - flex_nface=flex_nface, - flex_nwork=flex_nwork, - flex_group_root=flex_group_root, - flex_elemdataadr=flex_elemdataadr, - flex_shell=flex_shell, - flex_shelldataadr=flex_shelldataadr, - flex_radius=flex_radius, - flex_workadr=flex_workadr, - flex_worknum=flex_worknum, + flex_bvh_id=wp.array(flex_bvh_id, dtype=wp.uint64), + flex_group_root=wp.array(flex_group_root, dtype=int), flex_render_smooth=flex_render_smooth, + bvh_nflexgeom=len(flex_geom_flexid), + flex_dim_np=mjm.flex_dim, + flex_geom_flexid=wp.array(flex_geom_flexid, dtype=int), + flex_geom_edgeid=wp.array(flex_geom_edgeid, dtype=int), bvh=None, bvh_id=None, - lower=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3), - upper=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3), - group=wp.zeros(nworld * bvh_ngeom, dtype=int), + lower=wp.zeros(nworld * (bvh_ngeom + len(flex_geom_flexid)), dtype=wp.vec3), + upper=wp.zeros(nworld * (bvh_ngeom + len(flex_geom_flexid)), dtype=wp.vec3), + group=wp.zeros(nworld * (bvh_ngeom + len(flex_geom_flexid)), dtype=int), group_root=wp.zeros(nworld, dtype=int), ray=ray, rgb_data=wp.zeros((nworld, ri), dtype=wp.uint32), @@ -2757,6 +2882,9 @@ def create_render_context( depth_adr=wp.array(depth_adr, dtype=int), render_rgb=wp.array(render_rgb, dtype=bool), render_depth=wp.array(render_depth, dtype=bool), + seg_data=wp.zeros((nworld, max(si, 1)), dtype=int), + seg_adr=wp.array(seg_adr, dtype=int), + render_seg=wp.array(render_seg, dtype=bool), znear=znear, total_rays=int(total), ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/island.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/island.py index b5ae2846eb..c021db3f01 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/island.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/island.py @@ -13,12 +13,13 @@ # limitations under the License. # ============================================================================== +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src import types from mujoco.mjx.third_party.mujoco_warp._src.types import ConstraintType from mujoco.mjx.third_party.mujoco_warp._src.types import EqType from mujoco.mjx.third_party.mujoco_warp._src.types import ObjType from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope -import warp as wp @wp.kernel @@ -180,17 +181,17 @@ def tree_edges(m: types.Model, d: types.Data, tree_tree: wp.array3d(dtype=int)): @wp.kernel def _flood_fill( - # Model: - ntree: int, - # In: - tree_tree_in: wp.array3d(dtype=int), - labels_in: wp.array2d(dtype=int), - stack_in: wp.array2d(dtype=int), - # Data out: - nisland_out: wp.array(dtype=int), - tree_island_out: wp.array2d(dtype=int), - # Out: - stack_out: wp.array2d(dtype=int), + # Model: + ntree: int, + # In: + tree_tree_in: wp.array3d(dtype=int), + labels_in: wp.array2d(dtype=int), + stack_in: wp.array2d(dtype=int), + # Data out: + nisland_out: wp.array(dtype=int), + tree_island_out: wp.array2d(dtype=int), + # Out: + stack_out: wp.array2d(dtype=int), ): """DFS flood fill to discover islands using tree_tree matrix.""" worldid = wp.tid() @@ -257,8 +258,8 @@ def island(m: types.Model, d: types.Data): stack_scratch = wp.empty((d.nworld, m.ntree * m.ntree), dtype=int) wp.launch( - _flood_fill, - dim=d.nworld, - inputs=[m.ntree, tree_tree, d.tree_island, stack_scratch], - outputs=[d.nisland, d.tree_island, stack_scratch], + _flood_fill, + dim=d.nworld, + inputs=[m.ntree, tree_tree, d.tree_island, stack_scratch], + outputs=[d.nisland, d.tree_island, stack_scratch], ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/math.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/math.py index ec49041e58..3ce2d1c5d3 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/math.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/math.py @@ -83,6 +83,35 @@ def quat_to_mat(quat: wp.quat) -> wp.mat33: ) +@wp.func +def quat_z2vec(vec: wp.vec3) -> wp.quat: + """Compute quaternion performing rotation from z-axis to given vector.""" + quat = wp.quat(0.0, 0.0, 0.0, 1.0) + + # normalize vector; if too small, no rotation + norm = wp.length(vec) + if norm < types.MJ_MINVAL: + return quat + vec = vec / norm + + axis = wp.vec3(-vec[1], vec[0], 0.0) + a = wp.length(axis) + + # almost parallel + if a < types.MJ_MINVAL: + # opposite: 180 deg rotation around x axis + if vec[2] < 0.0: + quat = wp.quat(1.0, 0.0, 0.0, 0.0) + return quat + + # make quaternion from angle and axis + axis = axis / a + angle = wp.atan2(a, vec[2]) + quat = axis_angle_to_quat(axis, angle) + + return quat + + @wp.func def quat_inv(quat: wp.quat) -> wp.quat: return wp.quat(quat[0], -quat[1], -quat[2], -quat[3]) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/passive.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/passive.py index 0bcd13af52..4abcff26b1 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/passive.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/passive.py @@ -89,8 +89,8 @@ def _spring_damper_dof_passive( stiffness = jnt_stiffness[worldid % jnt_stiffness.shape[0], jntid] damping = dof_damping[worldid % dof_damping.shape[0], dofid] - has_stiffness = stiffness != 0.0 and not opt_disableflags & DisableBit.SPRING - has_damping = damping != 0.0 and not opt_disableflags & DisableBit.DAMPER + has_stiffness = stiffness != 0.0 and not (opt_disableflags & DisableBit.SPRING) + has_damping = damping != 0.0 and not (opt_disableflags & DisableBit.DAMPER) if not has_stiffness: qfrc_spring_out[worldid, dofid] = 0.0 @@ -182,11 +182,14 @@ def _spring_damper_dof_passive( @wp.kernel def _spring_damper_tendon_passive( # Model: + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), tendon_stiffness: wp.array2d(dtype=float), tendon_damping: wp.array2d(dtype=float), tendon_lengthspring: wp.array2d(dtype=wp.vec2), # Data in: - ten_J_in: wp.array3d(dtype=float), + ten_J_in: wp.array2d(dtype=float), ten_length_in: wp.array2d(dtype=float), ten_velocity_in: wp.array2d(dtype=float), # In: @@ -196,7 +199,7 @@ def _spring_damper_tendon_passive( qfrc_spring_out: wp.array2d(dtype=float), qfrc_damper_out: wp.array2d(dtype=float), ): - worldid, tenid, dofid = wp.tid() + worldid, tenid, dofid_sparse = wp.tid() stiffness = tendon_stiffness[worldid % tendon_stiffness.shape[0], tenid] damping = tendon_damping[worldid % tendon_damping.shape[0], tenid] @@ -207,7 +210,13 @@ def _spring_damper_tendon_passive( if not has_stiffness and not has_damping: return - J = ten_J_in[worldid, tenid, dofid] + rownnz = ten_J_rownnz[tenid] + if dofid_sparse >= rownnz: + return + rowadr = ten_J_rowadr[tenid] + sparseid = rowadr + dofid_sparse + J = ten_J_in[worldid, sparseid] + dofid = ten_J_colind[sparseid] if has_stiffness: # compute spring force along tendon @@ -265,28 +274,28 @@ def _gravity_force( @wp.kernel def _fluid_force( - # Model: - opt_wind: wp.array(dtype=wp.vec3), - opt_density: wp.array(dtype=float), - opt_viscosity: wp.array(dtype=float), - body_rootid: wp.array(dtype=int), - body_geomnum: wp.array(dtype=int), - body_geomadr: wp.array(dtype=int), - body_mass: wp.array2d(dtype=float), - body_inertia: wp.array2d(dtype=wp.vec3), - geom_type: wp.array(dtype=int), - geom_size: wp.array2d(dtype=wp.vec3), - geom_fluid: wp.array2d(dtype=float), - body_fluid_ellipsoid: wp.array(dtype=bool), - # Data in: - xipos_in: wp.array2d(dtype=wp.vec3), - ximat_in: wp.array2d(dtype=wp.mat33), - geom_xpos_in: wp.array2d(dtype=wp.vec3), - geom_xmat_in: wp.array2d(dtype=wp.mat33), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cvel_in: wp.array2d(dtype=wp.spatial_vector), - # Out: - fluid_applied_out: wp.array2d(dtype=wp.spatial_vector), + # Model: + opt_wind: wp.array(dtype=wp.vec3), + opt_density: wp.array(dtype=float), + opt_viscosity: wp.array(dtype=float), + body_rootid: wp.array(dtype=int), + body_geomnum: wp.array(dtype=int), + body_geomadr: wp.array(dtype=int), + body_mass: wp.array2d(dtype=float), + body_inertia: wp.array2d(dtype=wp.vec3), + geom_type: wp.array(dtype=int), + geom_size: wp.array2d(dtype=wp.vec3), + geom_fluid: wp.array2d(dtype=float), + body_fluid_ellipsoid: wp.array(dtype=bool), + # Data in: + xipos_in: wp.array2d(dtype=wp.vec3), + ximat_in: wp.array2d(dtype=wp.mat33), + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cvel_in: wp.array2d(dtype=wp.spatial_vector), + # Out: + fluid_applied_out: wp.array2d(dtype=wp.spatial_vector), ): """Computes body-space fluid forces for both inertia-box and ellipsoid models.""" worldid, bodyid = wp.tid() @@ -495,29 +504,29 @@ def _fluid(m: Model, d: Data): fluid_applied = wp.empty((d.nworld, m.nbody), dtype=wp.spatial_vector) wp.launch( - _fluid_force, - dim=(d.nworld, m.nbody), - inputs=[ - m.opt.wind, - m.opt.density, - m.opt.viscosity, - m.body_rootid, - m.body_geomnum, - m.body_geomadr, - m.body_mass, - m.body_inertia, - m.geom_type, - m.geom_size, - m.geom_fluid, - m.body_fluid_ellipsoid, - d.xipos, - d.ximat, - d.geom_xpos, - d.geom_xmat, - d.subtree_com, - d.cvel, - ], - outputs=[fluid_applied], + _fluid_force, + dim=(d.nworld, m.nbody), + inputs=[ + m.opt.wind, + m.opt.density, + m.opt.viscosity, + m.body_rootid, + m.body_geomnum, + m.body_geomadr, + m.body_mass, + m.body_inertia, + m.geom_type, + m.geom_size, + m.geom_fluid, + m.body_fluid_ellipsoid, + d.xipos, + d.ximat, + d.geom_xpos, + d.geom_xmat, + d.subtree_com, + d.cvel, + ], + outputs=[fluid_applied], ) support.apply_ft(m, d, fluid_applied, d.qfrc_fluid, False) @@ -565,6 +574,7 @@ def _flex_elasticity( flex_edgeadr: wp.array(dtype=int), flex_elemadr: wp.array(dtype=int), flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elemedgeadr: wp.array(dtype=int), flex_vertbodyid: wp.array(dtype=int), flex_elem: wp.array(dtype=int), @@ -590,32 +600,39 @@ def _flex_elasticity( f = i break + local_elemid = elemid - flex_elemadr[f] dim = flex_dim[f] nvert = dim + 1 nedge = nvert * (nvert - 1) / 2 edges = wp.where( - dim == 3, - wp.matrix(0, 1, 1, 2, 2, 0, 2, 3, 0, 3, 1, 3, shape=(6, 2), dtype=int), - wp.matrix(1, 2, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, shape=(6, 2), dtype=int), + dim == 1, + wp.matrix(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, shape=(6, 2), dtype=int), + wp.where( + dim == 3, + wp.matrix(0, 1, 1, 2, 2, 0, 2, 3, 0, 3, 1, 3, shape=(6, 2), dtype=int), + wp.matrix(1, 2, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, shape=(6, 2), dtype=int), + ), ) if timestep > 0.0 and not dsbl_damper: kD = flex_damping[f] / timestep else: kD = 0.0 + elem_data_adr = flex_elemdataadr[f] + local_elemid * (dim + 1) + vbase = flex_vertadr[f] gradient = wp.matrix(0.0, shape=(6, 6)) for e in range(nedge): - vert0 = flex_elem[(dim + 1) * elemid + edges[e, 0]] - vert1 = flex_elem[(dim + 1) * elemid + edges[e, 1]] - xpos0 = flexvert_xpos_in[worldid, vert0] - xpos1 = flexvert_xpos_in[worldid, vert1] + vert0 = flex_elem[elem_data_adr + edges[e, 0]] + vert1 = flex_elem[elem_data_adr + edges[e, 1]] + xpos0 = flexvert_xpos_in[worldid, vbase + vert0] + xpos1 = flexvert_xpos_in[worldid, vbase + vert1] for i in range(3): gradient[e, 0 + i] = xpos0[i] - xpos1[i] gradient[e, 3 + i] = xpos1[i] - xpos0[i] elongation = wp.spatial_vectorf(0.0) for e in range(nedge): - idx = flex_elemedge[elemid * nedge + e] + idx = flex_elemedge[flex_elemedgeadr[f] + local_elemid * nedge + e] vel = flexedge_velocity_in[worldid, flex_edgeadr[f] + idx] deformed = flexedge_length_in[worldid, flex_edgeadr[f] + idx] reference = flexedge_length0[flex_edgeadr[f] + idx] @@ -638,7 +655,7 @@ def _flex_elasticity( force[edges[ed2, i], x] -= elongation[ed1] * gradient[ed2, 3 * i + x] * metric[ed1, ed2] for v in range(nvert): - vert = flex_elem[(dim + 1) * elemid + v] + vert = flex_elem[elem_data_adr + v] bodyid = flex_vertbodyid[flex_vertadr[f] + vert] for x in range(3): wp.atomic_add(qfrc_spring_out, worldid, body_dofadr[bodyid] + x, force[v, x]) @@ -742,8 +759,11 @@ def passive(m: Model, d: Data): if m.ntendon: wp.launch( _spring_damper_tendon_passive, - dim=(d.nworld, m.ntendon, m.nv), + dim=(d.nworld, m.ntendon, m.max_ten_J_rownnz), inputs=[ + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, m.tendon_stiffness, m.tendon_damping, m.tendon_lengthspring, @@ -772,6 +792,7 @@ def passive(m: Model, d: Data): m.flex_edgeadr, m.flex_elemadr, m.flex_elemnum, + m.flex_elemdataadr, m.flex_elemedgeadr, m.flex_vertbodyid, m.flex_elem, diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/ray.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/ray.py index 5eaea82124..57d6a5ba74 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/ray.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/ray.py @@ -752,7 +752,8 @@ def ray_mesh_with_bvh_anyhit( @wp.func def ray_flex_with_bvh( # In: - bvh_id: wp.uint64, + flex_bvh_id: wp.array(dtype=wp.uint64), + flexid: int, group_root: int, pnt: wp.vec3, vec: wp.vec3, @@ -769,7 +770,7 @@ def ray_flex_with_bvh( n = wp.vec3(0.0, 0.0, 0.0) f = int(-1) - hit = wp.mesh_query_ray(bvh_id, pnt, vec, max_t, t, u, v, sign, n, f, group_root) + hit = wp.mesh_query_ray(flex_bvh_id[flexid], pnt, vec, max_t, t, u, v, sign, n, f, group_root) if hit: return t, n, u, v, f @@ -777,6 +778,23 @@ def ray_flex_with_bvh( return -1.0, wp.vec3(0.0, 0.0, 0.0), 0.0, 0.0, -1 +@wp.func +def ray_flex_with_bvh_anyhit( + # In: + flex_bvh_id: wp.array(dtype=wp.uint64), + flexid: int, + group_root: int, + pnt: wp.vec3, + vec: wp.vec3, + max_t: float, +) -> bool: + """Returns True if there is any hit for ray flex intersections. + + Requires wp.Mesh be constructed and their ids to be passed. Flex are already in world space. + """ + return wp.mesh_query_ray_anyhit(flex_bvh_id[flexid], pnt, vec, max_t, group_root) + + @wp.func def ray_geom(pos: wp.vec3, mat: wp.mat33, size: wp.vec3, pnt: wp.vec3, vec: wp.vec3, geomtype: int) -> Tuple[float, wp.vec3]: """Returns distance along ray to intersection with geom and normal at intersection point. diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/render.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/render.py index 28a4284f2d..bc8d16c3ad 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/render.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/render.py @@ -23,6 +23,7 @@ from mujoco.mjx.third_party.mujoco_warp._src.ray import ray_cylinder from mujoco.mjx.third_party.mujoco_warp._src.ray import ray_ellipsoid from mujoco.mjx.third_party.mujoco_warp._src.ray import ray_flex_with_bvh +from mujoco.mjx.third_party.mujoco_warp._src.ray import ray_flex_with_bvh_anyhit from mujoco.mjx.third_party.mujoco_warp._src.ray import ray_mesh_with_bvh from mujoco.mjx.third_party.mujoco_warp._src.ray import ray_mesh_with_bvh_anyhit from mujoco.mjx.third_party.mujoco_warp._src.ray import ray_plane @@ -39,10 +40,6 @@ wp.set_module_options({"enable_backward": False}) -# TODO(team): remove after mjwarp depends on warp-lang >= 1.12 in pyproject.toml -from mujoco.mjx.third_party.mujoco_warp._src.types import TEXTURE_DTYPE - - @wp.func def sample_texture( # Model: @@ -51,7 +48,7 @@ def sample_texture( # In: geom_id: int, tex_repeat: wp.vec2, - tex: TEXTURE_DTYPE, + tex: wp.Texture2D, pos: wp.vec3, rot: wp.mat33, mesh_facetexcoord: wp.array(dtype=wp.vec3i), @@ -94,17 +91,26 @@ def cast_ray( geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: bvh_id: wp.uint64, group_root: int, - world_id: int, + worldid: int, bvh_ngeom: int, + flex_bvh_ngeom: int, enabled_geom_ids: wp.array(dtype=int), mesh_bvh_id: wp.array(dtype=wp.uint64), hfield_bvh_id: wp.array(dtype=wp.uint64), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), ray_origin_world: wp.vec3, ray_dir_world: wp.vec3, ) -> Tuple[int, float, wp.vec3, float, float, int, int]: @@ -118,91 +124,127 @@ def cast_ray( query = wp.bvh_query_ray(bvh_id, ray_origin_world, ray_dir_world, group_root) bounds_nr = int(0) + ngeom = bvh_ngeom + flex_bvh_ngeom while wp.bvh_query_next(query, bounds_nr, dist): gi_global = bounds_nr - gi_bvh_local = gi_global - (world_id * bvh_ngeom) - gi = enabled_geom_ids[gi_bvh_local] + local_id = gi_global - (worldid * ngeom) + d = float(-1.0) hit_mesh_id = int(-1) u = float(0.0) v = float(0.0) f = int(-1) n = wp.vec3(0.0, 0.0, 0.0) + hit_geom_id = int(-1) + + if local_id < bvh_ngeom: + gi = enabled_geom_ids[local_id] + gtype = geom_type[gi] + else: + gi = local_id - bvh_ngeom + gtype = GeomType.FLEX + + hit_geom_id = gi # TODO: Investigate branch elimination with static loop unrolling - if geom_type[gi] == GeomType.PLANE: + if gtype == GeomType.PLANE: d, n = ray_plane( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.HFIELD: + if gtype == GeomType.HFIELD: d, n, u, v, f, geom_hfield_id = ray_mesh_with_bvh( hfield_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, dist, ) - if geom_type[gi] == GeomType.SPHERE: + if gtype == GeomType.SPHERE: d, n = ray_sphere( - geom_xpos_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi][0] * geom_size[world_id % geom_size.shape[0], gi][0], + geom_xpos_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi][0] * geom_size[worldid % geom_size.shape[0], gi][0], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.ELLIPSOID: + if gtype == GeomType.ELLIPSOID: d, n = ray_ellipsoid( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CAPSULE: + if gtype == GeomType.CAPSULE: d, n = ray_capsule( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CYLINDER: + if gtype == GeomType.CYLINDER: d, n = ray_cylinder( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.BOX: + if gtype == GeomType.BOX: d, all, n = ray_box( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.MESH: + if gtype == GeomType.MESH: d, n, u, v, f, hit_mesh_id = ray_mesh_with_bvh( mesh_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, dist, ) + if gtype == GeomType.FLEX: + hit_geom_id = -2 + flexid = flex_geom_flexid[gi] + edge_id = flex_geom_edgeid[gi] + + if edge_id >= 0: + edge = flex_edge[edge_id] + vert_adr = flex_vertadr[flexid] + v0 = flexvert_xpos_in[worldid, vert_adr + edge[0]] + v1 = flexvert_xpos_in[worldid, vert_adr + edge[1]] + pos = 0.5 * (v0 + v1) + vec = v1 - v0 + + length = wp.length(vec) + edgeq = math.quat_z2vec(vec) + mat = math.quat_to_mat(edgeq) + size = wp.vec3(flex_radius[flexid], 0.5 * length, 0.0) + + d, n = ray_capsule(pos, mat, size, ray_origin_world, ray_dir_world) + hit_mesh_id = flexid + else: + flex_gr = flex_group_root[worldid, flexid] + d, n, u, v, f = ray_flex_with_bvh(flex_bvh_id, flexid, flex_gr, ray_origin_world, ray_dir_world, dist) + if d >= 0.0: + hit_mesh_id = flexid if d >= 0.0 and d < dist: dist = d normal = n - geom_id = gi + geom_id = hit_geom_id bary_u = u bary_v = v face_idx = f @@ -217,17 +259,26 @@ def cast_ray_first_hit( geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: bvh_id: wp.uint64, group_root: int, - world_id: int, + worldid: int, bvh_ngeom: int, + bvh_nflexgeom: int, enabled_geom_ids: wp.array(dtype=int), mesh_bvh_id: wp.array(dtype=wp.uint64), hfield_bvh_id: wp.array(dtype=wp.uint64), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), ray_origin_world: wp.vec3, ray_dir_world: wp.vec3, max_dist: float, @@ -235,81 +286,119 @@ def cast_ray_first_hit( """A simpler version of casting rays that only checks for the first hit.""" query = wp.bvh_query_ray(bvh_id, ray_origin_world, ray_dir_world, group_root) bounds_nr = int(0) + ngeom = bvh_ngeom + bvh_nflexgeom while wp.bvh_query_next(query, bounds_nr, max_dist): gi_global = bounds_nr - gi_bvh_local = gi_global - (world_id * bvh_ngeom) - gi = enabled_geom_ids[gi_bvh_local] + local_id = gi_global - (worldid * ngeom) + + d = float(-1.0) + n = wp.vec3(0.0, 0.0, 0.0) + + if local_id < bvh_ngeom: + gi = enabled_geom_ids[local_id] + gtype = geom_type[gi] + else: + gi = local_id - bvh_ngeom + gtype = GeomType.FLEX # TODO: Investigate branch elimination with static loop unrolling - if geom_type[gi] == GeomType.PLANE: + if gtype == GeomType.PLANE: d, n = ray_plane( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.HFIELD: + if gtype == GeomType.HFIELD: d, n, u, v, f, geom_hfield_id = ray_mesh_with_bvh( hfield_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, max_dist, ) - if geom_type[gi] == GeomType.SPHERE: + if gtype == GeomType.SPHERE: d, n = ray_sphere( - geom_xpos_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi][0] * geom_size[world_id % geom_size.shape[0], gi][0], + geom_xpos_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi][0] * geom_size[worldid % geom_size.shape[0], gi][0], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.ELLIPSOID: + if gtype == GeomType.ELLIPSOID: d, n = ray_ellipsoid( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CAPSULE: + if gtype == GeomType.CAPSULE: d, n = ray_capsule( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CYLINDER: + if gtype == GeomType.CYLINDER: d, n = ray_cylinder( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.BOX: + if gtype == GeomType.BOX: d, all, n = ray_box( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.MESH: + if gtype == GeomType.MESH: hit = ray_mesh_with_bvh_anyhit( mesh_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, max_dist, ) d = 0.0 if hit else -1.0 + if gtype == GeomType.FLEX: + flexid = flex_geom_flexid[gi] + edge_id = flex_geom_edgeid[gi] + + if edge_id >= 0: + edge = flex_edge[edge_id] + vert_adr = flex_vertadr[flexid] + v0 = flexvert_xpos_in[worldid, vert_adr + edge[0]] + v1 = flexvert_xpos_in[worldid, vert_adr + edge[1]] + pos = 0.5 * (v0 + v1) + vec = v1 - v0 + + length = wp.length(vec) + edgeq = math.quat_z2vec(vec) + mat = math.quat_to_mat(edgeq) + size = wp.vec3(flex_radius[flexid], 0.5 * length, 0.0) + + d, n = ray_capsule(pos, mat, size, ray_origin_world, ray_dir_world) + else: + hit = ray_flex_with_bvh_anyhit( + flex_bvh_id, + flexid, + flex_group_root[worldid, flexid], + ray_origin_world, + ray_dir_world, + max_dist, + ) + d = 0.0 if hit else -1.0 if d >= 0.0 and d < max_dist: return True @@ -323,18 +412,27 @@ def compute_lighting( geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: use_shadows: bool, bvh_id: wp.uint64, group_root: int, bvh_ngeom: int, + bvh_nflexgeom: int, enabled_geom_ids: wp.array(dtype=int), - world_id: int, + worldid: int, mesh_bvh_id: wp.array(dtype=wp.uint64), hfield_bvh_id: wp.array(dtype=wp.uint64), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), lightactive: bool, lighttype: int, lightcastshadow: bool, @@ -385,15 +483,24 @@ def compute_lighting( geom_type, geom_dataid, geom_size, + flex_vertadr, + flex_edge, + flex_radius, geom_xpos_in, geom_xmat_in, + flexvert_xpos_in, bvh_id, group_root, - world_id, + worldid, bvh_ngeom, + bvh_nflexgeom, enabled_geom_ids, mesh_bvh_id, hfield_bvh_id, + flex_geom_flexid, + flex_geom_edgeid, + flex_bvh_id, + flex_group_root, shadow_origin, L, max_t, @@ -418,6 +525,7 @@ def render(m: Model, d: Data, rc: RenderContext): """ rc.rgb_data.fill_(rc.background_color) rc.depth_data.fill_(0.0) + rc.seg_data.fill_(-1) @wp.kernel(module="unique", enable_backward=False) def _render_megakernel( @@ -434,6 +542,9 @@ def _render_megakernel( light_type: wp.array2d(dtype=int), light_castshadow: wp.array2d(dtype=bool), light_active: wp.array2d(dtype=bool), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), mesh_faceadr: wp.array(dtype=int), mat_texid: wp.array3d(dtype=int), mat_texrepeat: wp.array2d(dtype=wp.vec2), @@ -445,21 +556,25 @@ def _render_megakernel( cam_xmat_in: wp.array2d(dtype=wp.mat33), light_xpos_in: wp.array2d(dtype=wp.vec3), light_xdir_in: wp.array2d(dtype=wp.vec3), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: nrender: int, use_shadows: bool, bvh_ngeom: int, + bvh_nflexgeom: int, cam_res: wp.array(dtype=wp.vec2i), cam_id_map: wp.array(dtype=int), ray: wp.array(dtype=wp.vec3), rgb_adr: wp.array(dtype=int), depth_adr: wp.array(dtype=int), + seg_adr: wp.array(dtype=int), render_rgb: wp.array(dtype=bool), render_depth: wp.array(dtype=bool), + render_seg: wp.array(dtype=bool), bvh_id: wp.uint64, group_root: wp.array(dtype=int), - flex_bvh_id: wp.uint64, - flex_group_root: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), enabled_geom_ids: wp.array(dtype=int), mesh_bvh_id: wp.array(dtype=wp.uint64), mesh_facetexcoord: wp.array(dtype=wp.vec3i), @@ -467,46 +582,48 @@ def _render_megakernel( mesh_texcoord_offsets: wp.array(dtype=int), hfield_bvh_id: wp.array(dtype=wp.uint64), flex_rgba: wp.array(dtype=wp.vec4), - # TODO: remove after mjwarp depends on warp-lang >= 1.12 in pyproject.toml - textures: wp.array(dtype=TEXTURE_DTYPE), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + textures: wp.array(dtype=wp.Texture2D), # Out: rgb_out: wp.array2d(dtype=wp.uint32), depth_out: wp.array2d(dtype=float), + seg_out: wp.array2d(dtype=int), ): - world_idx, ray_idx = wp.tid() + worldid, rayid = wp.tid() - # Map global ray_idx -> (cam_idx, ray_idx_local) using cumulative sizes + # Map global rayid -> (cam_idx, rayid_local) using cumulative sizes cam_idx = int(-1) - ray_idx_local = int(-1) + rayid_local = int(-1) accum = int(0) for i in range(nrender): num_i = cam_res[i][0] * cam_res[i][1] - if ray_idx < accum + num_i: + if rayid < accum + num_i: cam_idx = i - ray_idx_local = ray_idx - accum + rayid_local = rayid - accum break accum += num_i - if cam_idx == -1 or ray_idx_local < 0: + if cam_idx == -1 or rayid_local < 0: return - if not render_rgb[cam_idx] and not render_depth[cam_idx]: + if not render_rgb[cam_idx] and not render_depth[cam_idx] and not render_seg[cam_idx]: return # Map active camera index to MuJoCo camera ID mujoco_cam_id = cam_id_map[cam_idx] if wp.static(rc.use_precomputed_rays): - ray_dir_local_cam = ray[ray_idx] + ray_dir_local_cam = ray[rayid] else: img_w = cam_res[cam_idx][0] img_h = cam_res[cam_idx][1] - px = ray_idx_local % img_w - py = ray_idx_local // img_w + px = rayid_local % img_w + py = rayid_local // img_w ray_dir_local_cam = compute_ray( cam_projection[mujoco_cam_id], - cam_fovy[world_idx % cam_fovy.shape[0], mujoco_cam_id], + cam_fovy[worldid % cam_fovy.shape[0], mujoco_cam_id], cam_sensorsize[mujoco_cam_id], - cam_intrinsic[world_idx % cam_intrinsic.shape[0], mujoco_cam_id], + cam_intrinsic[worldid % cam_intrinsic.shape[0], mujoco_cam_id], img_w, img_h, px, @@ -514,38 +631,37 @@ def _render_megakernel( wp.static(rc.znear), ) - ray_dir_world = cam_xmat_in[world_idx, mujoco_cam_id] @ ray_dir_local_cam - ray_origin_world = cam_xpos_in[world_idx, mujoco_cam_id] + ray_dir_world = cam_xmat_in[worldid, mujoco_cam_id] @ ray_dir_local_cam + ray_origin_world = cam_xpos_in[worldid, mujoco_cam_id] geom_id, dist, normal, u, v, f, mesh_id = cast_ray( geom_type, geom_dataid, geom_size, + flex_vertadr, + flex_edge, + flex_radius, geom_xpos_in, geom_xmat_in, + flexvert_xpos_in, bvh_id, - group_root[world_idx], - world_idx, + group_root[worldid], + worldid, bvh_ngeom, + bvh_nflexgeom, enabled_geom_ids, mesh_bvh_id, hfield_bvh_id, + flex_geom_flexid, + flex_geom_edgeid, + flex_bvh_id, + flex_group_root, ray_origin_world, ray_dir_world, ) - if wp.static(m.nflex > 0): - d, n, u, v, f = ray_flex_with_bvh( - flex_bvh_id, - flex_group_root[world_idx], - ray_origin_world, - ray_dir_world, - dist, - ) - if d >= 0.0 and d < dist: - dist = d - normal = n - geom_id = -2 + if render_seg[cam_idx] and geom_id != -1: + seg_out[worldid, seg_adr[cam_idx] + rayid_local] = geom_id # Early Out if geom_id == -1: @@ -556,9 +672,7 @@ def _render_megakernel( # In camera-local coordinates, the optical axis is -Z. The Z-component of the # normalized ray direction is negative, so -ray_dir_local_cam[2] gives cos(θ) # between the ray and the optical axis. - depth_out[world_idx, depth_adr[cam_idx] + ray_idx_local] = dist * ( - -ray_dir_local_cam[2] - ) + depth_out[worldid, depth_adr[cam_idx] + rayid_local] = dist * (-ray_dir_local_cam[2]) if not render_rgb[cam_idx]: return @@ -567,31 +681,30 @@ def _render_megakernel( hit_point = ray_origin_world + ray_dir_world * dist if geom_id == -2: - # TODO: Currently flex textures are not supported, and only the first rgba value - # is used until further flex support is added. - color = flex_rgba[0] - elif geom_matid[world_idx % geom_matid.shape[0], geom_id] == -1: - color = geom_rgba[world_idx % geom_rgba.shape[0], geom_id] + # We encode flex_id in mesh_id for flex ray hits during cast_ray + color = flex_rgba[mesh_id] + elif geom_matid[worldid % geom_matid.shape[0], geom_id] == -1: + color = geom_rgba[worldid % geom_rgba.shape[0], geom_id] else: - color = mat_rgba[world_idx % mat_rgba.shape[0], geom_matid[world_idx % geom_matid.shape[0], geom_id]] + color = mat_rgba[worldid % mat_rgba.shape[0], geom_matid[worldid % geom_matid.shape[0], geom_id]] base_color = wp.vec3(color[0], color[1], color[2]) hit_color = base_color if wp.static(rc.use_textures): if geom_id != -2: - mat_id = geom_matid[world_idx % geom_matid.shape[0], geom_id] + mat_id = geom_matid[worldid % geom_matid.shape[0], geom_id] if mat_id >= 0: - tex_id = mat_texid[world_idx % mat_texid.shape[0], mat_id, 1] + tex_id = mat_texid[worldid % mat_texid.shape[0], mat_id, 1] if tex_id >= 0: tex_color = sample_texture( geom_type, mesh_faceadr, geom_id, - mat_texrepeat[world_idx % mat_texrepeat.shape[0], mat_id], + mat_texrepeat[worldid % mat_texrepeat.shape[0], mat_id], textures[tex_id], - geom_xpos_in[world_idx, geom_id], - geom_xmat_in[world_idx, geom_id], + geom_xpos_in[worldid, geom_id], + geom_xmat_in[worldid, geom_id], mesh_facetexcoord, mesh_texcoord, mesh_texcoord_offsets, @@ -616,21 +729,30 @@ def _render_megakernel( geom_type, geom_dataid, geom_size, + flex_vertadr, + flex_edge, + flex_radius, geom_xpos_in, geom_xmat_in, + flexvert_xpos_in, use_shadows, bvh_id, - group_root[world_idx], + group_root[worldid], bvh_ngeom, + bvh_nflexgeom, enabled_geom_ids, - world_idx, + worldid, mesh_bvh_id, hfield_bvh_id, - light_active[world_idx % light_active.shape[0], l], - light_type[world_idx % light_type.shape[0], l], - light_castshadow[world_idx % light_castshadow.shape[0], l], - light_xpos_in[world_idx, l], - light_xdir_in[world_idx, l], + flex_geom_flexid, + flex_geom_edgeid, + flex_bvh_id, + flex_group_root, + light_active[worldid % light_active.shape[0], l], + light_type[worldid % light_type.shape[0], l], + light_castshadow[worldid % light_castshadow.shape[0], l], + light_xpos_in[worldid, l], + light_xdir_in[worldid, l], normal, hit_point, ) @@ -639,7 +761,7 @@ def _render_megakernel( hit_color = wp.min(result, wp.vec3(1.0, 1.0, 1.0)) hit_color = wp.max(hit_color, wp.vec3(0.0, 0.0, 0.0)) - rgb_out[world_idx, rgb_adr[cam_idx] + ray_idx_local] = pack_rgba_to_uint32( + rgb_out[worldid, rgb_adr[cam_idx] + rayid_local] = pack_rgba_to_uint32( hit_color[0] * 255.0, hit_color[1] * 255.0, hit_color[2] * 255.0, @@ -662,6 +784,9 @@ def _render_megakernel( m.light_type, m.light_castshadow, m.light_active, + m.flex_vertadr, + m.flex_edge, + m.flex_radius, m.mesh_faceadr, m.mat_texid, m.mat_texrepeat, @@ -672,16 +797,20 @@ def _render_megakernel( d.cam_xmat, d.light_xpos, d.light_xdir, + d.flexvert_xpos, rc.nrender, rc.use_shadows, rc.bvh_ngeom, + rc.bvh_nflexgeom, rc.cam_res, rc.cam_id_map, rc.ray, rc.rgb_adr, rc.depth_adr, + rc.seg_adr, rc.render_rgb, rc.render_depth, + rc.render_seg, rc.bvh_id, rc.group_root, rc.flex_bvh_id, @@ -693,10 +822,13 @@ def _render_megakernel( rc.mesh_texcoord_offsets, rc.hfield_bvh_id, rc.flex_rgba, + rc.flex_geom_flexid, + rc.flex_geom_edgeid, rc.textures, ], outputs=[ rc.rgb_data, rc.depth_data, + rc.seg_data, ], ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/render_util.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/render_util.py index ccb808ff47..36958f8e98 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/render_util.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/render_util.py @@ -206,3 +206,41 @@ def get_depth(rc: RenderContext, camera_index: int, depth_scale: float, depth_ou inputs=[rc.depth_data, rc.depth_adr, camera_index, depth_scale], outputs=[depth_out], ) + + +@wp.kernel +def _extract_seg_kernel( + # In: + seg_data: wp.array2d(dtype=int), + seg_adr: wp.array(dtype=int), + camera_index: int, + # Out: + seg_out: wp.array3d(dtype=int), +): + """Extract per-pixel geom IDs from the render context buffers for a given camera index.""" + worldid, pixelid = wp.tid() + xid = pixelid % seg_out.shape[2] + yid = pixelid // seg_out.shape[2] + + seg_adr_offset = seg_adr[camera_index] + seg_out[worldid, yid, xid] = seg_data[worldid, seg_adr_offset + pixelid] + + +def get_segmentation(rc: RenderContext, camera_index: int, seg_out: wp.array3d(dtype=int)): + """Get the segmentation data from the render context buffers for a given camera index. + + Each pixel contains the MuJoCo geom ID of the geometry hit by the ray, -1 for + background, or -2 for flex bodies. + + Args: + rc: The render context on device. + camera_index: The index of the camera to get the segmentation data for. + seg_out: The output array to store the geom IDs in, with shape + (nworld, height, width). + """ + wp.launch( + _extract_seg_kernel, + dim=(seg_out.shape[0], seg_out.shape[1] * seg_out.shape[2]), + inputs=[rc.seg_data, rc.seg_adr, camera_index], + outputs=[seg_out], + ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/sensor.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/sensor.py index 859ddf8a86..2c8177b831 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/sensor.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/sensor.py @@ -15,12 +15,17 @@ from typing import Any, Tuple +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src import math from mujoco.mjx.third_party.mujoco_warp._src import ray from mujoco.mjx.third_party.mujoco_warp._src import smooth from mujoco.mjx.third_party.mujoco_warp._src import support from mujoco.mjx.third_party.mujoco_warp._src.collision_sdf import get_sdf_params from mujoco.mjx.third_party.mujoco_warp._src.collision_sdf import sdf +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXCONPAIR +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL from mujoco.mjx.third_party.mujoco_warp._src.types import ConeType from mujoco.mjx.third_party.mujoco_warp._src.types import ConstraintType from mujoco.mjx.third_party.mujoco_warp._src.types import ContactType @@ -28,9 +33,6 @@ from mujoco.mjx.third_party.mujoco_warp._src.types import DataType from mujoco.mjx.third_party.mujoco_warp._src.types import DisableBit from mujoco.mjx.third_party.mujoco_warp._src.types import JointType -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXCONPAIR -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL from mujoco.mjx.third_party.mujoco_warp._src.types import Model from mujoco.mjx.third_party.mujoco_warp._src.types import ObjType from mujoco.mjx.third_party.mujoco_warp._src.types import SensorType @@ -40,10 +42,10 @@ from mujoco.mjx.third_party.mujoco_warp._src.types import vec6 from mujoco.mjx.third_party.mujoco_warp._src.types import vec8 from mujoco.mjx.third_party.mujoco_warp._src.types import vec8i +from mujoco.mjx.third_party.mujoco_warp._src.types import vec_pluginattr from mujoco.mjx.third_party.mujoco_warp._src.util_misc import inside_geom from mujoco.mjx.third_party.mujoco_warp._src.warp_util import cache_kernel from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope -import warp as wp wp.set_module_options({"enable_backward": False}) @@ -2081,16 +2083,16 @@ def _transform_spatial(vec: wp.spatial_vector, dif: wp.vec3) -> wp.vec3: @wp.kernel def _preprocess_tactile_contacts( - # Model: - body_weldid: wp.array(dtype=int), - geom_bodyid: wp.array(dtype=int), - # Data in: - contact_geom_in: wp.array(dtype=wp.vec2i), - contact_worldid_in: wp.array(dtype=int), - nacon_in: wp.array(dtype=int), - # Out: - weld_geom_count_out: wp.array2d(dtype=int), - weld_geom_list_out: wp.array3d(dtype=int), + # Model: + body_weldid: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + # Data in: + contact_geom_in: wp.array(dtype=wp.vec2i), + contact_worldid_in: wp.array(dtype=int), + nacon_in: wp.array(dtype=int), + # Out: + weld_geom_count_out: wp.array2d(dtype=int), + weld_geom_list_out: wp.array3d(dtype=int), ): conid = wp.tid() ncon = nacon_in[0] @@ -2118,42 +2120,43 @@ def _preprocess_tactile_contacts( @wp.kernel def _sensor_tactile( - # Model: - body_rootid: wp.array(dtype=int), - body_weldid: wp.array(dtype=int), - oct_child: wp.array(dtype=vec8i), - oct_aabb: wp.array2d(dtype=wp.vec3), - oct_coeff: wp.array(dtype=vec8), - geom_type: wp.array(dtype=int), - geom_bodyid: wp.array(dtype=int), - geom_size: wp.array2d(dtype=wp.vec3), - mesh_vertadr: wp.array(dtype=int), - mesh_vertnum: wp.array(dtype=int), - mesh_octadr: wp.array(dtype=int), - mesh_normaladr: wp.array(dtype=int), - mesh_normalnum: wp.array(dtype=int), - mesh_vert: wp.array(dtype=wp.vec3), - mesh_normal: wp.array(dtype=wp.vec3), - mesh_quat: wp.array(dtype=wp.quat), - sensor_objid: wp.array(dtype=int), - sensor_refid: wp.array(dtype=int), - sensor_dim: wp.array(dtype=int), - sensor_adr: wp.array(dtype=int), - plugin: wp.array(dtype=int), - plugin_attr: wp.array(dtype=wp.vec3f), - geom_plugin_index: wp.array(dtype=int), - taxel_vertadr: wp.array(dtype=int), - taxel_sensorid: wp.array(dtype=int), - # Data in: - geom_xpos_in: wp.array2d(dtype=wp.vec3), - geom_xmat_in: wp.array2d(dtype=wp.mat33), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cvel_in: wp.array2d(dtype=wp.spatial_vector), - # In: - weld_geom_count_in: wp.array2d(dtype=int), - weld_geom_list_in: wp.array3d(dtype=int), - # Data out: - sensordata_out: wp.array2d(dtype=float), + # Model: + body_rootid: wp.array(dtype=int), + body_weldid: wp.array(dtype=int), + oct_child: wp.array(dtype=vec8i), + oct_aabb: wp.array2d(dtype=wp.vec3), + oct_coeff: wp.array(dtype=vec8), + geom_type: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + geom_dataid: wp.array(dtype=int), + geom_size: wp.array2d(dtype=wp.vec3), + mesh_vertadr: wp.array(dtype=int), + mesh_vertnum: wp.array(dtype=int), + mesh_octadr: wp.array(dtype=int), + mesh_normaladr: wp.array(dtype=int), + mesh_normalnum: wp.array(dtype=int), + mesh_vert: wp.array(dtype=wp.vec3), + mesh_normal: wp.array(dtype=wp.vec3), + mesh_quat: wp.array(dtype=wp.quat), + sensor_objid: wp.array(dtype=int), + sensor_refid: wp.array(dtype=int), + sensor_dim: wp.array(dtype=int), + sensor_adr: wp.array(dtype=int), + plugin: wp.array(dtype=int), + plugin_attr: wp.array(dtype=vec_pluginattr), + geom_plugin_index: wp.array(dtype=int), + taxel_vertadr: wp.array(dtype=int), + taxel_sensorid: wp.array(dtype=int), + # Data in: + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cvel_in: wp.array2d(dtype=wp.spatial_vector), + # In: + weld_geom_count_in: wp.array2d(dtype=int), + weld_geom_list_in: wp.array3d(dtype=int), + # Data out: + sensordata_out: wp.array2d(dtype=float), ): worldid, taxelid = wp.tid() @@ -2211,40 +2214,25 @@ def _sensor_tactile( contact_type = geom_type[geom] plugin_attributes, plugin_index, volume_data, mesh_data = get_sdf_params( - oct_child, - oct_aabb, - oct_coeff, - mesh_octadr, - plugin, - plugin_attr, - contact_type, - geom_size[worldid % geom_size.shape[0], geom], - plugin_id, - mesh_id, + oct_child, + oct_aabb, + oct_coeff, + mesh_octadr, + plugin, + plugin_attr, + contact_type, + geom_size[worldid % geom_size.shape[0], geom], + plugin_id, + geom_dataid[geom], ) - depth = wp.min( - sdf( - contact_type, - lpos, - plugin_attributes, - plugin_index, - volume_data, - mesh_data, - ), - 0.0, - ) + depth = wp.min(sdf(contact_type, lpos, plugin_attributes, plugin_index, volume_data, mesh_data), 0.0) if depth >= 0.0: continue - vel_sensor = _transform_spatial( - cvel_in[worldid, parent_weld], - xpos - subtree_com_in[worldid, body_rootid[parent_weld]], - ) + vel_sensor = _transform_spatial(cvel_in[worldid, parent_weld], xpos - subtree_com_in[worldid, body_rootid[parent_weld]]) vel_other = _transform_spatial( - cvel_in[worldid, body], - geom_xpos_in[worldid, geom] - - subtree_com_in[worldid, body_rootid[body]], + cvel_in[worldid, body], geom_xpos_in[worldid, geom] - subtree_com_in[worldid, body_rootid[body]] ) vel_rel = vel_sensor - vel_other @@ -2259,24 +2247,9 @@ def _sensor_tactile( forceT[2] = wp.abs(wp.dot(vel_rel, tang2)) dim = sensor_dim[sensor_id] // 3 - wp.atomic_add( - sensordata_out, - worldid, - sensor_adr[sensor_id] + 0 * dim + vertid, - forceT[0], - ) - wp.atomic_add( - sensordata_out, - worldid, - sensor_adr[sensor_id] + 1 * dim + vertid, - forceT[1], - ) - wp.atomic_add( - sensordata_out, - worldid, - sensor_adr[sensor_id] + 2 * dim + vertid, - forceT[2], - ) + wp.atomic_add(sensordata_out, worldid, sensor_adr[sensor_id] + 0 * dim + vertid, forceT[0]) + wp.atomic_add(sensordata_out, worldid, sensor_adr[sensor_id] + 1 * dim + vertid, forceT[1]) + wp.atomic_add(sensordata_out, worldid, sensor_adr[sensor_id] + 2 * dim + vertid, forceT[2]) @wp.func @@ -2507,60 +2480,61 @@ def sensor_acc(m: Model, d: Data): weld_geom_count = wp.zeros((d.nworld, m.nbody), dtype=int) weld_geom_list = wp.full((d.nworld, m.nbody, MJ_MAXCONPAIR), -1, dtype=int) wp.launch( - _preprocess_tactile_contacts, - dim=d.naconmax, - inputs=[ - m.body_weldid, - m.geom_bodyid, - d.contact.geom, - d.contact.worldid, - d.nacon, - ], - outputs=[ - weld_geom_count, - weld_geom_list, - ], + _preprocess_tactile_contacts, + dim=d.naconmax, + inputs=[ + m.body_weldid, + m.geom_bodyid, + d.contact.geom, + d.contact.worldid, + d.nacon, + ], + outputs=[ + weld_geom_count, + weld_geom_list, + ], ) wp.launch( - _sensor_tactile, - dim=(d.nworld, m.nsensortaxel), - inputs=[ - m.body_rootid, - m.body_weldid, - m.oct_child, - m.oct_aabb, - m.oct_coeff, - m.geom_type, - m.geom_bodyid, - m.geom_size, - m.mesh_vertadr, - m.mesh_vertnum, - m.mesh_octadr, - m.mesh_normaladr, - m.mesh_normalnum, - m.mesh_vert, - m.mesh_normal, - m.mesh_quat, - m.sensor_objid, - m.sensor_refid, - m.sensor_dim, - m.sensor_adr, - m.plugin, - m.plugin_attr, - m.geom_plugin_index, - m.taxel_vertadr, - m.taxel_sensorid, - d.geom_xpos, - d.geom_xmat, - d.subtree_com, - d.cvel, - weld_geom_count, - weld_geom_list, - ], - outputs=[ - d.sensordata, - ], + _sensor_tactile, + dim=(d.nworld, m.nsensortaxel), + inputs=[ + m.body_rootid, + m.body_weldid, + m.oct_child, + m.oct_aabb, + m.oct_coeff, + m.geom_type, + m.geom_bodyid, + m.geom_dataid, + m.geom_size, + m.mesh_vertadr, + m.mesh_vertnum, + m.mesh_octadr, + m.mesh_normaladr, + m.mesh_normalnum, + m.mesh_vert, + m.mesh_normal, + m.mesh_quat, + m.sensor_objid, + m.sensor_refid, + m.sensor_dim, + m.sensor_adr, + m.plugin, + m.plugin_attr, + m.geom_plugin_index, + m.taxel_vertadr, + m.taxel_sensorid, + d.geom_xpos, + d.geom_xmat, + d.subtree_com, + d.cvel, + weld_geom_count, + weld_geom_list, + ], + outputs=[ + d.sensordata, + ], ) sensor_contact_nmatch = wp.empty((d.nworld, m.nsensorcontact), dtype=int) @@ -2882,12 +2856,12 @@ def energy_pos(m: Model, d: Data): wp.launch(_energy_pos_zero, dim=d.nworld, outputs=[d.energy]) # init potential energy: -sum_i(body_i.mass * dot(gravity, body_i.pos)) - if not m.opt.disableflags & DisableBit.GRAVITY: + if not (m.opt.disableflags & DisableBit.GRAVITY): wp.launch( _energy_pos_gravity, dim=(d.nworld, m.nbody - 1), inputs=[m.opt.gravity, m.body_mass, d.xipos], outputs=[d.energy] ) - if not m.opt.disableflags & DisableBit.SPRING: + if not (m.opt.disableflags & DisableBit.SPRING): # add joint-level springs wp.launch( _energy_pos_passive_joint, diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/smooth.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/smooth.py index 51b3640334..51dfadb461 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/smooth.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/smooth.py @@ -14,29 +14,29 @@ # ============================================================================== +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src import math from mujoco.mjx.third_party.mujoco_warp._src import support from mujoco.mjx.third_party.mujoco_warp._src import util_misc +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL from mujoco.mjx.third_party.mujoco_warp._src.types import CamLightType from mujoco.mjx.third_party.mujoco_warp._src.types import ConeType from mujoco.mjx.third_party.mujoco_warp._src.types import Data from mujoco.mjx.third_party.mujoco_warp._src.types import DisableBit from mujoco.mjx.third_party.mujoco_warp._src.types import EqType from mujoco.mjx.third_party.mujoco_warp._src.types import JointType -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MAXVAL -from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL from mujoco.mjx.third_party.mujoco_warp._src.types import Model from mujoco.mjx.third_party.mujoco_warp._src.types import ObjType -from mujoco.mjx.third_party.mujoco_warp._src.types import SPARSE_CONSTRAINT_JACOBIAN from mujoco.mjx.third_party.mujoco_warp._src.types import TileSet from mujoco.mjx.third_party.mujoco_warp._src.types import TrnType +from mujoco.mjx.third_party.mujoco_warp._src.types import WrapType +from mujoco.mjx.third_party.mujoco_warp._src.types import vec5 from mujoco.mjx.third_party.mujoco_warp._src.types import vec10 from mujoco.mjx.third_party.mujoco_warp._src.types import vec11 -from mujoco.mjx.third_party.mujoco_warp._src.types import vec5 -from mujoco.mjx.third_party.mujoco_warp._src.types import WrapType from mujoco.mjx.third_party.mujoco_warp._src.warp_util import cache_kernel from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope -import warp as wp wp.set_module_options({"enable_backward": False}) @@ -227,38 +227,59 @@ def _site_local_to_global( @wp.kernel def _flex_vertices( # Model: + nflex: int, + flex_vertadr: wp.array(dtype=int), + flex_vertnum: wp.array(dtype=int), flex_vertbodyid: wp.array(dtype=int), + flex_vert: wp.array(dtype=wp.vec3), + flex_centered: wp.array(dtype=bool), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), + xmat_in: wp.array2d(dtype=wp.mat33), # Data out: flexvert_xpos_out: wp.array2d(dtype=wp.vec3), ): worldid, vertid = wp.tid() - flexvert_xpos_out[worldid, vertid] = xpos_in[worldid, flex_vertbodyid[vertid]] + + for f in range(nflex): + locid = vertid - flex_vertadr[f] + if locid >= 0 and locid < flex_vertnum[f]: + break + + bodyid = flex_vertbodyid[vertid] + xpos = xpos_in[worldid, bodyid] + + if flex_centered[f]: + flexvert_xpos_out[worldid, vertid] = xpos + else: + xmat = xmat_in[worldid, bodyid] + local_pos = flex_vert[vertid] + flexvert_xpos_out[worldid, vertid] = xmat @ local_pos + xpos @wp.kernel def _flex_edges( - # Model: - nflex: int, - body_rootid: wp.array(dtype=int), - body_dofadr: wp.array(dtype=int), - flex_vertadr: wp.array(dtype=int), - flex_edgeadr: wp.array(dtype=int), - flex_edgenum: wp.array(dtype=int), - flex_vertbodyid: wp.array(dtype=int), - flex_edge: wp.array(dtype=wp.vec2i), - flexedge_J_rowadr: wp.array(dtype=int), - flexedge_J_colind: wp.array(dtype=int), - # Data in: - qvel_in: wp.array2d(dtype=float), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - flexvert_xpos_in: wp.array2d(dtype=wp.vec3), - # Data out: - flexedge_J_out: wp.array2d(dtype=float), - flexedge_length_out: wp.array2d(dtype=float), - flexedge_velocity_out: wp.array2d(dtype=float), + # Model: + nflex: int, + body_rootid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + flex_vertadr: wp.array(dtype=int), + flex_edgeadr: wp.array(dtype=int), + flex_edgenum: wp.array(dtype=int), + flex_vertbodyid: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flexedge_J_rowadr: wp.array(dtype=int), + flexedge_J_colind: wp.array(dtype=int), + # Data in: + qvel_in: wp.array2d(dtype=float), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), + # Data out: + flexedge_J_out: wp.array2d(dtype=float), + flexedge_length_out: wp.array2d(dtype=float), + flexedge_velocity_out: wp.array2d(dtype=float), ): worldid, edgeid = wp.tid() for i in range(nflex): @@ -281,42 +302,56 @@ def _flex_edges( b1 = flex_vertbodyid[vbase0] b2 = flex_vertbodyid[vbase1] - dofi = body_dofadr[b1] - dofj = body_dofadr[b2] - - vel1 = wp.vec3( - qvel_in[worldid, dofi], - qvel_in[worldid, dofi + 1], - qvel_in[worldid, dofi + 2], - ) - vel2 = wp.vec3( - qvel_in[worldid, dofj], - qvel_in[worldid, dofj + 1], - qvel_in[worldid, dofj + 2], - ) - flexedge_velocity_out[worldid, edgeid] = wp.dot(vel2 - vel1, edge) + dofnum1 = body_dofnum[b1] + dofnum2 = body_dofnum[b2] + + # velocity via Jacobian: sum_k J_k * qvel_k for each body + vel = float(0.0) + if dofnum1 > 0: + dofi = body_dofadr[b1] + offset1 = pos1 - wp.vec3(subtree_com_in[worldid, body_rootid[b1]]) + for k in range(dofnum1): + cdof = cdof_in[worldid, dofi + k] + cdof_ang = wp.spatial_top(cdof) + cdof_lin = wp.spatial_bottom(cdof) + jacp1 = cdof_lin + wp.cross(cdof_ang, offset1) + vel -= wp.dot(jacp1, edge) * qvel_in[worldid, dofi + k] + if dofnum2 > 0: + dofj = body_dofadr[b2] + offset2 = pos2 - wp.vec3(subtree_com_in[worldid, body_rootid[b2]]) + for k in range(dofnum2): + cdof = cdof_in[worldid, dofj + k] + cdof_ang = wp.spatial_top(cdof) + cdof_lin = wp.spatial_bottom(cdof) + jacp2 = cdof_lin + wp.cross(cdof_ang, offset2) + vel += wp.dot(jacp2, edge) * qvel_in[worldid, dofj + k] + flexedge_velocity_out[worldid, edgeid] = vel rowadr = flexedge_J_rowadr[edgeid] - - # compute offsets once per body (avoids 12 redundant tree-ancestry walks in jac_dof) - offset1 = pos1 - wp.vec3(subtree_com_in[worldid, body_rootid[b1]]) - offset2 = pos2 - wp.vec3(subtree_com_in[worldid, body_rootid[b2]]) + nnz_offset = 0 # body1 DOFs: b1 is in subtree, b2 is not -> jacdif = 0 - jacp1 = -jacp1 - for k in range(3): - cdof = cdof_in[worldid, dofi + k] - cdof_ang = wp.spatial_top(cdof) - cdof_lin = wp.spatial_bottom(cdof) - jacp1 = cdof_lin + wp.cross(cdof_ang, offset1) - flexedge_J_out[worldid, rowadr + k] = wp.dot(-jacp1, edge) + if dofnum1 > 0: + dofi = body_dofadr[b1] + offset1 = pos1 - wp.vec3(subtree_com_in[worldid, body_rootid[b1]]) + for k in range(dofnum1): + cdof = cdof_in[worldid, dofi + k] + cdof_ang = wp.spatial_top(cdof) + cdof_lin = wp.spatial_bottom(cdof) + jacp1 = cdof_lin + wp.cross(cdof_ang, offset1) + flexedge_J_out[worldid, rowadr + nnz_offset + k] = wp.dot(-jacp1, edge) + nnz_offset += dofnum1 # body2 DOFs: b2 is in subtree, b1 is not -> jacdif = jacp2 - 0 = jacp2 - for k in range(3): - cdof = cdof_in[worldid, dofj + k] - cdof_ang = wp.spatial_top(cdof) - cdof_lin = wp.spatial_bottom(cdof) - jacp2 = cdof_lin + wp.cross(cdof_ang, offset2) - flexedge_J_out[worldid, rowadr + 3 + k] = wp.dot(jacp2, edge) + if dofnum2 > 0: + dofj = body_dofadr[b2] + offset2 = pos2 - wp.vec3(subtree_com_in[worldid, body_rootid[b2]]) + for k in range(dofnum2): + cdof = cdof_in[worldid, dofj + k] + cdof_ang = wp.spatial_top(cdof) + cdof_lin = wp.spatial_bottom(cdof) + jacp2 = cdof_lin + wp.cross(cdof_ang, offset2) + flexedge_J_out[worldid, rowadr + nnz_offset + k] = wp.dot(jacp2, edge) @event_scope @@ -382,13 +417,28 @@ def kinematics(m: Model, d: Data): @event_scope def flex(m: Model, d: Data): - wp.launch(_flex_vertices, dim=(d.nworld, m.nflexvert), inputs=[m.flex_vertbodyid, d.xpos], outputs=[d.flexvert_xpos]) + wp.launch( + _flex_vertices, + dim=(d.nworld, m.nflexvert), + inputs=[ + m.nflex, + m.flex_vertadr, + m.flex_vertnum, + m.flex_vertbodyid, + m.flex_vert, + m.flex_centered, + d.xpos, + d.xmat, + ], + outputs=[d.flexvert_xpos], + ) wp.launch( _flex_edges, dim=(d.nworld, m.nflexedge), inputs=[ m.nflex, m.body_rootid, + m.body_dofnum, m.body_dofadr, m.flex_vertadr, m.flex_edgeadr, @@ -790,9 +840,7 @@ def _qM_sparse( bodyid = dof_bodyid[dofid] # init M(i,i) with armature inertia - qM_out[worldid, 0, madr_ij] = dof_armature[ - worldid % dof_armature.shape[0], dofid - ] + qM_out[worldid, 0, madr_ij] = dof_armature[worldid % dof_armature.shape[0], dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(crb_in[worldid, bodyid], cdof_in[worldid, dofid]) @@ -869,35 +917,55 @@ def _tendon_armature( # Model: dof_parentid: wp.array(dtype=int), dof_Madr: wp.array(dtype=int), + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), tendon_armature: wp.array2d(dtype=float), is_sparse: bool, # Data in: - ten_J_in: wp.array3d(dtype=float), + ten_J_in: wp.array2d(dtype=float), # Data out: qM_out: wp.array3d(dtype=float), ): worldid, tenid, dofid = wp.tid() - if is_sparse: # is_sparse is not batched - madr_ij = dof_Madr[dofid] - armature = tendon_armature[worldid % tendon_armature.shape[0], tenid] if armature == 0.0: return - ten_Ji = ten_J_in[worldid, tenid, dofid] + rownnz = ten_J_rownnz[tenid] + if dofid >= rownnz: + return + rowadr = ten_J_rowadr[tenid] + dofid_sparse = dofid + sparseid = rowadr + dofid_sparse + dofid = ten_J_colind[sparseid] + ten_Ji = ten_J_in[worldid, sparseid] if ten_Ji == 0.0: return + if is_sparse: + madr_ij = dof_Madr[dofid] + # sparse backward pass over ancestors dofidi = dofid + ptr = dofid_sparse while dofid >= 0: - if dofid != dofidi: - ten_Jj = ten_J_in[worldid, tenid, dofid] - else: + if dofid == dofidi: ten_Jj = ten_Ji + else: + # scan pointer backward to find matching colind entry + while ptr >= 0: + sparseid = rowadr + ptr + if ten_J_colind[sparseid] <= dofid: + break + ptr -= 1 + if ptr >= 0 and ten_J_colind[sparseid] == dofid: + ten_Jj = ten_J_in[worldid, sparseid] + else: + ten_Jj = float(0.0) qMij = armature * ten_Jj * ten_Ji @@ -917,8 +985,17 @@ def tendon_armature(m: Model, d: Data): """Add tendon armature to qM.""" wp.launch( _tendon_armature, - dim=(d.nworld, m.ntendon, m.nv), - inputs=[m.dof_parentid, m.dof_Madr, m.tendon_armature, m.is_sparse, d.ten_J], + dim=(d.nworld, m.ntendon, m.max_ten_J_rownnz), + inputs=[ + m.dof_parentid, + m.dof_Madr, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, + m.tendon_armature, + m.is_sparse, + d.ten_J, + ], outputs=[d.qM], ) @@ -1504,19 +1581,93 @@ def rne_postconstraint(m: Model, d: Data): _rne_cfrc_backward(m, d) +@wp.func +def _accumulate_jac_dot_chain( + # Model: + body_parentid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + jnt_type: wp.array(dtype=int), + jnt_dofadr: wp.array(dtype=int), + dof_jntid: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + # Data in: + cdof_in: wp.array2d(dtype=wp.spatial_vector), + cvel_in: wp.array2d(dtype=wp.spatial_vector), + cdof_dot_in: wp.array2d(dtype=wp.spatial_vector), + # In: + offset: wp.vec3, + pvel_lin: wp.vec3, + dpnt: wp.vec3, + dvel: wp.vec3, + bodyid: int, + rowadr: int, + rownnz: int, + scale: float, + worldid: int, + # Out: + ten_Jdot_out: wp.array2d(dtype=float), +): + """Walk body chain from bodyid to root, accumulate Jdot contributions.""" + ptr = rownnz - 1 + bid = bodyid + while bid > 0: + bdofadr = body_dofadr[bid] + bdofnum = body_dofnum[bid] + # iterate DOFs in this body in descending order + for k_rev in range(bdofnum): + dof = bdofadr + bdofnum - 1 - k_rev + # scan pointer backward to find matching colind entry + while ptr >= 0: + sparseid = rowadr + ptr + if ten_J_colind[sparseid] <= dof: + break + ptr -= 1 + if ptr >= 0 and ten_J_colind[sparseid] == dof: + cdof = cdof_in[worldid, dof] + cdof_ang = wp.spatial_top(cdof) + cdof_lin = wp.spatial_bottom(cdof) + cdof_dot = cdof_dot_in[worldid, dof] + + # quaternion override: use cvel of DOF's body (which is bid) + dofjntid = dof_jntid[dof] + jnttype = jnt_type[dofjntid] + jntdofadr = jnt_dofadr[dofjntid] + if (jnttype == JointType.BALL) or ((jnttype == JointType.FREE) and dof >= jntdofadr + 3): + cdof_dot = math.motion_cross(cvel_in[worldid, bid], cdof) + + cdof_dot_ang = wp.spatial_top(cdof_dot) + cdof_dot_lin = wp.spatial_bottom(cdof_dot) + + # jacp_dot (from jac_dot_dof) + jacp_dot = cdof_dot_lin + wp.cross(cdof_dot_ang, offset) + wp.cross(cdof_ang, pvel_lin) + + # jacp (from jac_dof) + jacp = cdof_lin + wp.cross(cdof_ang, offset) + + # combined: dot(jacdot, dpnt) + dot(jac, dvel) + Jdot = (wp.dot(jacp_dot, dpnt) + wp.dot(jacp, dvel)) * scale + if Jdot != 0.0: + wp.atomic_add(ten_Jdot_out[worldid], sparseid, Jdot) + bid = body_parentid[bid] + + @wp.kernel def _tendon_dot( # Model: - nv: int, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), jnt_type: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), dof_jntid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), tendon_adr: wp.array(dtype=int), tendon_num: wp.array(dtype=int), + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), tendon_armature: wp.array2d(dtype=float), wrap_type: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), @@ -1528,7 +1679,7 @@ def _tendon_dot( cvel_in: wp.array2d(dtype=wp.spatial_vector), cdof_dot_in: wp.array2d(dtype=wp.spatial_vector), # Out: - ten_Jdot_out: wp.array3d(dtype=float), + ten_Jdot_out: wp.array2d(dtype=float), ): worldid, tenid = wp.tid() @@ -1565,13 +1716,11 @@ def _tendon_dot( # init sequence; assume it start with site wpnt0 = site_xpos_in[worldid, id0] - bodyid0 = site_bodyid[id0] - pos0 = site_xpos_in[worldid, id0] - cvel0 = cvel_in[worldid, bodyid0] - subtree_com0 = subtree_com_in[worldid, body_rootid[bodyid0]] - dif0 = pos0 - subtree_com0 - wvel0 = wp.spatial_bottom(cvel0) - wp.cross(dif0, wp.spatial_top(cvel0)) wbody0 = site_bodyid[id0] + cvel0 = cvel_in[worldid, wbody0] + subtree_com0 = subtree_com_in[worldid, body_rootid[wbody0]] + offset0 = wpnt0 - subtree_com0 + pvel_lin0 = wp.spatial_bottom(cvel0) - wp.cross(offset0, wp.spatial_top(cvel0)) # second object is geom: process site-geom-site if (type1 == WrapType.SPHERE) or (type1 == WrapType.CYLINDER): @@ -1582,12 +1731,10 @@ def _tendon_dot( wbody1 = site_bodyid[id1] wpnt1 = site_xpos_in[worldid, id1] - bodyid1 = site_bodyid[id1] - pos1 = site_xpos_in[worldid, id1] - cvel1 = cvel_in[worldid, bodyid1] - subtree_com1 = subtree_com_in[worldid, body_rootid[bodyid1]] - dif1 = pos1 - subtree_com1 - wvel1 = wp.spatial_bottom(cvel1) - wp.cross(dif1, wp.spatial_top(cvel1)) + cvel1 = cvel_in[worldid, wbody1] + subtree_com1 = subtree_com_in[worldid, body_rootid[wbody1]] + offset1 = wpnt1 - subtree_com1 + pvel_lin1 = wp.spatial_bottom(cvel1) - wp.cross(offset1, wp.spatial_top(cvel1)) # accumulate moments if consecutive points are in different bodies if wbody0 != wbody1: @@ -1595,6 +1742,8 @@ def _tendon_dot( dpnt, norm = math.normalize_with_norm(wpnt1 - wpnt0) # dvel = d / dt (dpnt) + wvel0 = wp.spatial_bottom(cvel0) - wp.cross(wpnt0 - subtree_com0, wp.spatial_top(cvel0)) + wvel1 = wp.spatial_bottom(cvel1) - wp.cross(wpnt1 - subtree_com1, wp.spatial_top(cvel1)) dvel = wvel1 - wvel0 dot = wp.dot(dpnt, dvel) dvel += dpnt * (-dot) @@ -1603,75 +1752,55 @@ def _tendon_dot( else: dvel = wp.vec3(0.0) - # get endpoint Jacobian time derivatives, subtract - # TODO(team): parallelize? - for i in range(nv): - jac1, _ = support.jac_dot_dof( - body_parentid, - body_rootid, - jnt_type, - jnt_dofadr, - dof_bodyid, - dof_jntid, - subtree_com_in, - cdof_in, - cvel_in, - cdof_dot_in, - wpnt0, - wbody0, - i, - worldid, - ) - jac2, _ = support.jac_dot_dof( - body_parentid, - body_rootid, - jnt_type, - jnt_dofadr, - dof_bodyid, - dof_jntid, - subtree_com_in, - cdof_in, - cvel_in, - cdof_dot_in, - wpnt1, - wbody1, - i, - worldid, - ) - jacdif = jac2 - jac1 - - # chain rule, first term: Jdot += d / dt (jac2 - jac1) * dpnt - Jdot = wp.dot(jacdif, dpnt) - - # get endpoint Jacobians, subtract - jac1, _ = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - wpnt0, - wbody0, - i, - worldid, - ) - jac2, _ = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - wpnt1, - wbody1, - i, - worldid, - ) - jacdif = jac2 - jac1 - - # chain rule, second term: Jdot += (jac2 - jac1) * d / dt (dpnt) - Jdot += wp.dot(jacdif, dvel) + rownnz = ten_J_rownnz[tenid] + rowadr = ten_J_rowadr[tenid] + inv_divisor = math.safe_div(float(1.0), divisor) - ten_Jdot_out[worldid, tenid, i] += math.safe_div(Jdot, divisor) + # body0 contributes with negative sign, body1 with positive + _accumulate_jac_dot_chain( + body_parentid, + body_dofnum, + body_dofadr, + jnt_type, + jnt_dofadr, + dof_jntid, + ten_J_colind, + cdof_in, + cvel_in, + cdof_dot_in, + offset0, + pvel_lin0, + dpnt, + dvel, + wbody0, + rowadr, + rownnz, + -inv_divisor, + worldid, + ten_Jdot_out, + ) + _accumulate_jac_dot_chain( + body_parentid, + body_dofnum, + body_dofadr, + jnt_type, + jnt_dofadr, + dof_jntid, + ten_J_colind, + cdof_in, + cvel_in, + cdof_dot_in, + offset1, + pvel_lin1, + dpnt, + dvel, + wbody1, + rowadr, + rownnz, + inv_divisor, + worldid, + ten_Jdot_out, + ) # TODO(team): j += 2 if geom wrapping j += 1 @@ -1680,33 +1809,45 @@ def _tendon_dot( @wp.kernel def _tendon_bias_coef( # Model: + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), tendon_armature: wp.array2d(dtype=float), # Data in: qvel_in: wp.array2d(dtype=float), # In: - ten_Jdot_in: wp.array3d(dtype=float), + ten_Jdot_in: wp.array2d(dtype=float), # Out: ten_bias_coef_out: wp.array2d(dtype=float), ): - worldid, tenid, dofid = wp.tid() + worldid, tenid, dofid_sparse = wp.tid() armature = tendon_armature[worldid % tendon_armature.shape[0], tenid] if armature == 0.0: return - ten_Jdot = ten_Jdot_in[worldid, tenid, dofid] + rownnz = ten_J_rownnz[tenid] + if dofid_sparse >= rownnz: + return + rowadr = ten_J_rowadr[tenid] + sparseid = rowadr + dofid_sparse + ten_Jdot = ten_Jdot_in[worldid, sparseid] if ten_Jdot == 0.0: return + dofid = ten_J_colind[sparseid] wp.atomic_add(ten_bias_coef_out[worldid], tenid, ten_Jdot * qvel_in[worldid, dofid]) @wp.kernel def _tendon_bias_qfrc( # Model: + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), tendon_armature: wp.array2d(dtype=float), # Data in: - ten_J_in: wp.array3d(dtype=float), + ten_J_in: wp.array2d(dtype=float), # In: ten_bias_coef_in: wp.array2d(dtype=float), # Out: @@ -1718,10 +1859,18 @@ def _tendon_bias_qfrc( if armature == 0.0: return - ten_J = ten_J_in[worldid, tenid, dofid] + rownnz = ten_J_rownnz[tenid] + if dofid >= rownnz: + return + rowadr = ten_J_rowadr[tenid] + sparseid = rowadr + dofid + ten_J = ten_J_in[worldid, sparseid] + if ten_J == 0.0: return + dofid = ten_J_colind[sparseid] + wp.atomic_add(qfrc_out[worldid], dofid, ten_J * armature * ten_bias_coef_in[worldid, tenid]) @@ -1735,21 +1884,24 @@ def tendon_bias(m: Model, d: Data, qfrc: wp.array2d(dtype=float)): qfrc: Force. """ # time derivative of tendon Jacobian - ten_Jdot = wp.zeros((d.nworld, m.ntendon, m.nv), dtype=float) + ten_Jdot = wp.zeros((d.nworld, m.nJten), dtype=float) wp.launch( _tendon_dot, dim=(d.nworld, m.ntendon), inputs=[ - m.nv, m.body_parentid, m.body_rootid, + m.body_dofnum, + m.body_dofadr, m.jnt_type, m.jnt_dofadr, - m.dof_bodyid, m.dof_jntid, m.site_bodyid, m.tendon_adr, m.tendon_num, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, m.tendon_armature, m.wrap_type, m.wrap_objid, @@ -1767,15 +1919,15 @@ def tendon_bias(m: Model, d: Data, qfrc: wp.array2d(dtype=float)): ten_bias_coef = wp.zeros((d.nworld, m.ntendon), dtype=float) wp.launch( _tendon_bias_coef, - dim=(d.nworld, m.ntendon, m.nv), - inputs=[m.tendon_armature, d.qvel, ten_Jdot], + dim=(d.nworld, m.ntendon, m.max_ten_J_rownnz), + inputs=[m.ten_J_rownnz, m.ten_J_rowadr, m.ten_J_colind, m.tendon_armature, d.qvel, ten_Jdot], outputs=[ten_bias_coef], ) wp.launch( _tendon_bias_qfrc, - dim=(d.nworld, m.ntendon, m.nv), - inputs=[m.tendon_armature, d.ten_J, ten_bias_coef], + dim=(d.nworld, m.ntendon, m.max_ten_J_rownnz), + inputs=[m.ten_J_rownnz, m.ten_J_rowadr, m.ten_J_colind, m.tendon_armature, d.ten_J, ten_bias_coef], outputs=[qfrc], ) @@ -1888,45 +2040,44 @@ def com_vel(m: Model, d: Data): @wp.kernel def _transmission( - # Model: - nv: int, - body_parentid: wp.array(dtype=int), - body_rootid: wp.array(dtype=int), - body_weldid: wp.array(dtype=int), - body_dofnum: wp.array(dtype=int), - body_dofadr: wp.array(dtype=int), - jnt_type: wp.array(dtype=int), - jnt_qposadr: wp.array(dtype=int), - jnt_dofadr: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), - dof_parentid: wp.array(dtype=int), - site_bodyid: wp.array(dtype=int), - site_quat: wp.array2d(dtype=wp.quat), - tendon_adr: wp.array(dtype=int), - tendon_num: wp.array(dtype=int), - wrap_type: wp.array(dtype=int), - wrap_objid: wp.array(dtype=int), - actuator_trntype: wp.array(dtype=int), - actuator_trnid: wp.array(dtype=wp.vec2i), - actuator_gear: wp.array2d(dtype=wp.spatial_vector), - actuator_cranklength: wp.array2d(dtype=float), - # Data in: - qpos_in: wp.array2d(dtype=float), - xquat_in: wp.array2d(dtype=wp.quat), - site_xpos_in: wp.array2d(dtype=wp.vec3), - site_xmat_in: wp.array2d(dtype=wp.mat33), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - ten_J_in: wp.array3d(dtype=float), - ten_length_in: wp.array2d(dtype=float), - # In: - moment_nnz: wp.array(dtype=int), - # Data out: - actuator_length_out: wp.array2d(dtype=float), - moment_rownnz_out: wp.array2d(dtype=int), - moment_rowadr_out: wp.array2d(dtype=int), - moment_colind_out: wp.array2d(dtype=int), - actuator_moment_out: wp.array2d(dtype=float), + # Model: + nv: int, + body_parentid: wp.array(dtype=int), + body_rootid: wp.array(dtype=int), + body_weldid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + jnt_type: wp.array(dtype=int), + jnt_qposadr: wp.array(dtype=int), + jnt_dofadr: wp.array(dtype=int), + dof_bodyid: wp.array(dtype=int), + dof_parentid: wp.array(dtype=int), + site_bodyid: wp.array(dtype=int), + site_quat: wp.array2d(dtype=wp.quat), + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + actuator_trntype: wp.array(dtype=int), + actuator_trnid: wp.array(dtype=wp.vec2i), + actuator_gear: wp.array2d(dtype=wp.spatial_vector), + actuator_cranklength: wp.array2d(dtype=float), + # Data in: + qpos_in: wp.array2d(dtype=float), + xquat_in: wp.array2d(dtype=wp.quat), + site_xpos_in: wp.array2d(dtype=wp.vec3), + site_xmat_in: wp.array2d(dtype=wp.mat33), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + ten_J_in: wp.array2d(dtype=float), + ten_length_in: wp.array2d(dtype=float), + # In: + moment_nnz: wp.array(dtype=int), + # Data out: + actuator_length_out: wp.array2d(dtype=float), + moment_rownnz_out: wp.array2d(dtype=int), + moment_rowadr_out: wp.array2d(dtype=int), + moment_colind_out: wp.array2d(dtype=int), + actuator_moment_out: wp.array2d(dtype=float), ): worldid, actid = wp.tid() trntype = actuator_trntype[actid] @@ -2068,28 +2219,12 @@ def _transmission( # get Jacobians of axis(jacA) and vec(jac) jacp, jacr = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - site_xpos_idslider, - site_bodyid[idslider], - da, - worldid, + body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_xpos_idslider, site_bodyid[idslider], da, worldid ) jacS = jacp jacA = wp.cross(jacr, axis) jac, _ = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - site_xpos_id, - site_bodyid[id], - da, - worldid, + body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_xpos_id, site_bodyid[id], da, worldid ) jac -= jacS @@ -2110,38 +2245,18 @@ def _transmission( gear0 = gear[0] actuator_length_out[worldid, actid] = ten_length_in[worldid, tenid] * gear0 - # fixed - adr = tendon_adr[tenid] - if wrap_type[adr] == WrapType.JOINT: - ten_num = tendon_num[tenid] - rowadr = wp.atomic_add(moment_nnz, worldid, ten_num) - moment_rownnz_out[worldid, actid] = ten_num - moment_rowadr_out[worldid, actid] = rowadr + rownnz_ten = ten_J_rownnz[tenid] + rowadr_ten = ten_J_rowadr[tenid] - for i in range(ten_num): - dofadr = jnt_dofadr[wrap_objid[adr + i]] - sparseid = rowadr + i - moment_colind_out[worldid, sparseid] = dofadr - actuator_moment_out[worldid, sparseid] = ( - ten_J_in[worldid, tenid, dofadr] * gear0 - ) - else: # spatial - # TODO(team): sparse tendon jacobian - ten_nnz = int(0) - for dofadr in range(nv): - if ten_J_in[worldid, tenid, dofadr] != 0.0: - ten_nnz += 1 - rowadr = wp.atomic_add(moment_nnz, worldid, ten_nnz) - moment_rownnz_out[worldid, actid] = ten_nnz - moment_rowadr_out[worldid, actid] = rowadr - ptr = int(0) - for dofadr in range(nv): - J = ten_J_in[worldid, tenid, dofadr] - if J != 0.0: - sparseid = rowadr + ptr - moment_colind_out[worldid, sparseid] = dofadr - actuator_moment_out[worldid, sparseid] = J * gear0 - ptr += 1 + rowadr_mom = wp.atomic_add(moment_nnz, worldid, rownnz_ten) + moment_rownnz_out[worldid, actid] = rownnz_ten + moment_rowadr_out[worldid, actid] = rowadr_mom + + for k in range(rownnz_ten): + sparseid_ten = rowadr_ten + k + sparseid_mom = rowadr_mom + k + moment_colind_out[worldid, sparseid_mom] = ten_J_colind[sparseid_ten] + actuator_moment_out[worldid, sparseid_mom] = ten_J_in[worldid, sparseid_ten] * gear0 elif trntype == TrnType.BODY: # cannot compute meaningful length, set to zero actuator_length_out[worldid, actid] = 0.0 @@ -2195,19 +2310,17 @@ def _transmission( ptr = ndof - 1 while da >= 0: jacp, jacr = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - site_xpos_in[worldid, siteid], - site_bodyid[siteid], - da, - worldid, - ) - moment = wp.dot(jacp, wrench_translation) + wp.dot( - jacr, wrench_rotation + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + site_xpos_in[worldid, siteid], + site_bodyid[siteid], + da, + worldid, ) + moment = wp.dot(jacp, wrench_translation) + wp.dot(jacr, wrench_rotation) sparseid = rowadr + ptr moment_colind_out[worldid, sparseid] = da actuator_moment_out[worldid, sparseid] = moment @@ -2306,26 +2419,10 @@ def _transmission( break jacp, jacr = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - site_xpos, - site_bodyid[siteid], - da, - worldid, + body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_xpos, site_bodyid[siteid], da, worldid ) jacpref, jacrref = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - ref_xpos, - site_bodyid[refid], - da, - worldid, + body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, ref_xpos, site_bodyid[refid], da, worldid ) moment = float(0.0) @@ -2349,37 +2446,37 @@ def _transmission( @wp.kernel def _transmission_body_moment( - # Model: - opt_cone: int, - body_parentid: wp.array(dtype=int), - body_rootid: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), - geom_bodyid: wp.array(dtype=int), - actuator_trnid: wp.array(dtype=wp.vec2i), - actuator_trntype_body_adr: wp.array(dtype=int), - # Data in: - subtree_com_in: wp.array2d(dtype=wp.vec3), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - moment_rowadr_in: wp.array2d(dtype=int), - contact_dist_in: wp.array(dtype=float), - contact_pos_in: wp.array(dtype=wp.vec3), - contact_frame_in: wp.array(dtype=wp.mat33), - contact_includemargin_in: wp.array(dtype=float), - contact_dim_in: wp.array(dtype=int), - contact_geom_in: wp.array(dtype=wp.vec2i), - contact_efc_address_in: wp.array2d(dtype=int), - contact_worldid_in: wp.array(dtype=int), - efc_J_rownnz_in: wp.array2d(dtype=int), - efc_J_rowadr_in: wp.array2d(dtype=int), - efc_J_colind_in: wp.array3d(dtype=int), - efc_J_in: wp.array3d(dtype=float), - nacon_in: wp.array(dtype=int), - # In: - efc_is_sparse: bool, - # Data out: - actuator_moment_out: wp.array2d(dtype=float), - # Out: - actuator_trntype_body_ncon_out: wp.array2d(dtype=int), + # Model: + opt_cone: int, + body_parentid: wp.array(dtype=int), + body_rootid: wp.array(dtype=int), + dof_bodyid: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + actuator_trnid: wp.array(dtype=wp.vec2i), + actuator_trntype_body_adr: wp.array(dtype=int), + # Data in: + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + moment_rowadr_in: wp.array2d(dtype=int), + contact_dist_in: wp.array(dtype=float), + contact_pos_in: wp.array(dtype=wp.vec3), + contact_frame_in: wp.array(dtype=wp.mat33), + contact_includemargin_in: wp.array(dtype=float), + contact_dim_in: wp.array(dtype=int), + contact_geom_in: wp.array(dtype=wp.vec2i), + contact_efc_address_in: wp.array2d(dtype=int), + contact_worldid_in: wp.array(dtype=int), + efc_J_rownnz_in: wp.array2d(dtype=int), + efc_J_rowadr_in: wp.array2d(dtype=int), + efc_J_colind_in: wp.array3d(dtype=int), + efc_J_in: wp.array3d(dtype=float), + nacon_in: wp.array(dtype=int), + # In: + efc_is_sparse: bool, + # Data out: + actuator_moment_out: wp.array2d(dtype=float), + # Out: + actuator_trntype_body_ncon_out: wp.array2d(dtype=int), ): trnbodyid, conid, dofid = wp.tid() actid = actuator_trntype_body_adr[trnbodyid] @@ -2427,20 +2524,12 @@ def _transmission_body_moment( efc_rowadr = efc_J_rowadr_in[worldid, efcid0] efc_sparseid = efc_rowadr + dofid colind = efc_J_colind_in[worldid, 0, efc_sparseid] - wp.atomic_add( - actuator_moment_out[worldid], - rowadr + colind, - efc_J_in[worldid, 0, efc_sparseid], - ) + wp.atomic_add(actuator_moment_out[worldid], rowadr + colind, efc_J_in[worldid, 0, efc_sparseid]) else: return else: colind = dofid - wp.atomic_add( - actuator_moment_out[worldid], - rowadr + colind, - efc_J_in[worldid, efcid0, dofid], - ) + wp.atomic_add(actuator_moment_out[worldid], rowadr + colind, efc_J_in[worldid, efcid0, dofid]) else: npyramid = contact_dim - 1 # number of frictional directions efc_force = 0.5 / float(npyramid) @@ -2453,20 +2542,12 @@ def _transmission_body_moment( efc_rowadr = efc_J_rowadr_in[worldid, efcid] efc_sparseid = efc_rowadr + dofid colind = efc_J_colind_in[worldid, 0, efc_sparseid] - wp.atomic_add( - actuator_moment_out[worldid], - rowadr + colind, - efc_J_in[worldid, 0, efc_sparseid] * efc_force, - ) + wp.atomic_add(actuator_moment_out[worldid], rowadr + colind, efc_J_in[worldid, 0, efc_sparseid] * efc_force) else: return else: colind = dofid - wp.atomic_add( - actuator_moment_out[worldid], - rowadr + colind, - efc_J_in[worldid, efcid, dofid] * efc_force, - ) + wp.atomic_add(actuator_moment_out[worldid], rowadr + colind, efc_J_in[worldid, efcid, dofid] * efc_force) # excluded contact in gap: get Jacobian, accumulate elif contact_exclude == 1: @@ -2487,46 +2568,28 @@ def _transmission_body_moment( colind = dofid jacp1, _ = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - contact_pos, - b1, - colind, - worldid, + body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, contact_pos, b1, colind, worldid ) jacp2, _ = support.jac_dof( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - contact_pos, - b2, - colind, - worldid, + body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, contact_pos, b2, colind, worldid ) jacdif = jacp2 - jacp1 # project Jacobian along the normal of the contact frame - wp.atomic_add( - actuator_moment_out[worldid], rowadr + colind, wp.dot(normal, jacdif) - ) + wp.atomic_add(actuator_moment_out[worldid], rowadr + colind, wp.dot(normal, jacdif)) @wp.kernel def _transmission_body_moment_scale( - # Model: - actuator_trntype_body_adr: wp.array(dtype=int), - # Data in: - moment_rowadr_in: wp.array2d(dtype=int), - # In: - actuator_trntype_body_ncon_in: wp.array2d(dtype=int), - # Data out: - actuator_moment_out: wp.array2d(dtype=float), + # Model: + actuator_trntype_body_adr: wp.array(dtype=int), + # Data in: + moment_rowadr_in: wp.array2d(dtype=int), + # In: + actuator_trntype_body_ncon_in: wp.array2d(dtype=int), + # Data out: + actuator_moment_out: wp.array2d(dtype=float), ): worldid, trnbodyid, dofid = wp.tid() @@ -2549,47 +2612,40 @@ def transmission(m: Model, d: Data): moment_nnz = wp.zeros((d.nworld,), dtype=int) wp.launch( - _transmission, - dim=(d.nworld, m.nu), - inputs=[ - m.nv, - m.body_parentid, - m.body_rootid, - m.body_weldid, - m.body_dofnum, - m.body_dofadr, - m.jnt_type, - m.jnt_qposadr, - m.jnt_dofadr, - m.dof_bodyid, - m.dof_parentid, - m.site_bodyid, - m.site_quat, - m.tendon_adr, - m.tendon_num, - m.wrap_type, - m.wrap_objid, - m.actuator_trntype, - m.actuator_trnid, - m.actuator_gear, - m.actuator_cranklength, - d.qpos, - d.xquat, - d.site_xpos, - d.site_xmat, - d.subtree_com, - d.cdof, - d.ten_J, - d.ten_length, - moment_nnz, - ], - outputs=[ - d.actuator_length, - d.moment_rownnz, - d.moment_rowadr, - d.moment_colind, - d.actuator_moment, - ], + _transmission, + dim=(d.nworld, m.nu), + inputs=[ + m.nv, + m.body_parentid, + m.body_rootid, + m.body_weldid, + m.body_dofnum, + m.body_dofadr, + m.jnt_type, + m.jnt_qposadr, + m.jnt_dofadr, + m.dof_bodyid, + m.dof_parentid, + m.site_bodyid, + m.site_quat, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, + m.actuator_trntype, + m.actuator_trnid, + m.actuator_gear, + m.actuator_cranklength, + d.qpos, + d.xquat, + d.site_xpos, + d.site_xmat, + d.subtree_com, + d.cdof, + d.ten_J, + d.ten_length, + moment_nnz, + ], + outputs=[d.actuator_length, d.moment_rownnz, d.moment_rowadr, d.moment_colind, d.actuator_moment], ) if m.nacttrnbody: @@ -2597,83 +2653,105 @@ def transmission(m: Model, d: Data): ncon = wp.zeros((d.nworld, m.nacttrnbody), dtype=int) wp.launch( - _transmission_body_moment, - dim=(m.nacttrnbody, d.naconmax, m.nv), - inputs=[ - m.opt.cone, - m.body_parentid, - m.body_rootid, - m.dof_bodyid, - m.geom_bodyid, - m.actuator_trnid, - m.actuator_trntype_body_adr, - d.subtree_com, - d.cdof, - d.moment_rowadr, - d.contact.dist, - d.contact.pos, - d.contact.frame, - d.contact.includemargin, - d.contact.dim, - d.contact.geom, - d.contact.efc_address, - d.contact.worldid, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.nacon, - SPARSE_CONSTRAINT_JACOBIAN, - ], - outputs=[d.actuator_moment, ncon], + _transmission_body_moment, + dim=(m.nacttrnbody, d.naconmax, m.nv), + inputs=[ + m.opt.cone, + m.body_parentid, + m.body_rootid, + m.dof_bodyid, + m.geom_bodyid, + m.actuator_trnid, + m.actuator_trntype_body_adr, + d.subtree_com, + d.cdof, + d.moment_rowadr, + d.contact.dist, + d.contact.pos, + d.contact.frame, + d.contact.includemargin, + d.contact.dim, + d.contact.geom, + d.contact.efc_address, + d.contact.worldid, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.nacon, + m.is_sparse, + ], + outputs=[d.actuator_moment, ncon], ) # scale moments wp.launch( - _transmission_body_moment_scale, - dim=(d.nworld, m.nacttrnbody, m.nv), - inputs=[m.actuator_trntype_body_adr, d.moment_rowadr, ncon], - outputs=[d.actuator_moment], + _transmission_body_moment_scale, + dim=(d.nworld, m.nacttrnbody, m.nv), + inputs=[m.actuator_trntype_body_adr, d.moment_rowadr, ncon], + outputs=[d.actuator_moment], ) -@wp.kernel -def _solve_LD_sparse_x_acc_up( - # In: - L: wp.array3d(dtype=float), - qLD_updates_: wp.array(dtype=wp.vec3i), - # Out: - x: wp.array2d(dtype=float), -): - worldid, nodeid = wp.tid() - update = qLD_updates_[nodeid] - i, k, Madr_ki = update[0], update[1], update[2] - wp.atomic_sub(x[worldid], i, L[worldid, 0, Madr_ki] * x[worldid, k]) - - -@wp.kernel -def _solve_LD_sparse_qLDiag_mul( - # In: - D: wp.array2d(dtype=float), - # Out: - out: wp.array2d(dtype=float), -): - worldid, dofid = wp.tid() - out[worldid, dofid] *= D[worldid, dofid] +@cache_kernel +def _solve_LD_sparse_fused(nv: int, nlevels: int): + """Fused sparse backsubstitution: UP + diag + DOWN in one kernel.""" + @wp.func_native(snippet="WP_TILE_SYNC();") + def _syncthreads(): + pass -@wp.kernel -def _solve_LD_sparse_x_acc_down( - # In: - L: wp.array3d(dtype=float), - qLD_updates_: wp.array(dtype=wp.vec3i), - # Out: - x: wp.array2d(dtype=float), -): - worldid, nodeid = wp.tid() - update = qLD_updates_[nodeid] - i, k, Madr_ki = update[0], update[1], update[2] - wp.atomic_sub(x[worldid], k, L[worldid, 0, Madr_ki] * x[worldid, i]) + @wp.kernel(module="unique", enable_backward=False) + def kernel( + # In: + L: wp.array3d(dtype=float), + D: wp.array2d(dtype=float), + all_updates: wp.array(dtype=wp.vec3i), + level_offsets: wp.array(dtype=int), + y: wp.array2d(dtype=float), + # Out: + x_out: wp.array2d(dtype=float), + ): + worldid, tid = wp.tid() + NV = wp.static(nv) + NLEVELS = wp.static(nlevels) + BLOCK_DIM = wp.block_dim() + + # Copy y to x_out + for dofid in range(tid, NV, BLOCK_DIM): + x_out[worldid, dofid] = y[worldid, dofid] + _syncthreads() + + # Forward substitution + for level in range(NLEVELS): + level_idx = NLEVELS - 1 - level + level_offset = level_offsets[level_idx] + level_size = level_offsets[level_idx + 1] - level_offset + + for u in range(tid, level_size, BLOCK_DIM): + update = all_updates[level_offset + u] + i, k, Madr_ki = update[0], update[1], update[2] + wp.atomic_sub(x_out[worldid], i, L[worldid, 0, Madr_ki] * x_out[worldid, k]) + _syncthreads() + + # Diagonal multiply + for dofid in range(tid, NV, BLOCK_DIM): + x_out[worldid, dofid] *= D[worldid, dofid] + _syncthreads() + + # Backward substitution + for level in range(NLEVELS): + level_idx = level + level_offset = level_offsets[level_idx] + level_size = level_offsets[level_idx + 1] - level_offset + + for u in range(tid, level_size, BLOCK_DIM): + update = all_updates[level_offset + u] + i, k, Madr_ki = update[0], update[1], update[2] + wp.atomic_sub(x_out[worldid], k, L[worldid, 0, Madr_ki] * x_out[worldid, i]) + _syncthreads() + + return kernel def _solve_LD_sparse( @@ -2685,14 +2763,20 @@ def _solve_LD_sparse( y: wp.array2d(dtype=float), ): """Computes sparse backsubstitution: x = inv(L'*D*L)*y.""" - wp.copy(x, y) - for qLD_updates in reversed(m.qLD_updates): - wp.launch(_solve_LD_sparse_x_acc_up, dim=(d.nworld, qLD_updates.size), inputs=[L, qLD_updates], outputs=[x]) - - wp.launch(_solve_LD_sparse_qLDiag_mul, dim=(d.nworld, m.nv), inputs=[D], outputs=[x]) + nlevels = len(m.qLD_updates) + if wp.get_device().is_cuda: + dim_block = m.block_dim.solve_LD_sparse_fused + else: + # Fallback for CPU + dim_block = 1 - for qLD_updates in m.qLD_updates: - wp.launch(_solve_LD_sparse_x_acc_down, dim=(d.nworld, qLD_updates.size), inputs=[L, qLD_updates], outputs=[x]) + wp.launch( + _solve_LD_sparse_fused(m.nv, nlevels), + dim=(d.nworld, dim_block), + inputs=[L, D, m.qLD_all_updates, m.qLD_level_offsets, y], + outputs=[x], + block_dim=dim_block, + ) @cache_kernel @@ -3005,6 +3089,9 @@ def _joint_tendon( # Model: jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), wrap_prm: wp.array(dtype=float), tendon_jnt_adr: wp.array(dtype=int), @@ -3012,34 +3099,87 @@ def _joint_tendon( # Data in: qpos_in: wp.array2d(dtype=float), # Data out: - ten_J_out: wp.array3d(dtype=float), + ten_J_out: wp.array2d(dtype=float), ten_length_out: wp.array2d(dtype=float), ): worldid, wrapid = wp.tid() - tendon_jnt_adr_ = tendon_jnt_adr[wrapid] - wrap_jnt_adr_ = wrap_jnt_adr[wrapid] - - wrap_objid_ = wrap_objid[wrap_jnt_adr_] - prm = wrap_prm[wrap_jnt_adr_] + tenid = tendon_jnt_adr[wrapid] + wrapjntid = wrap_jnt_adr[wrapid] + wrapobjid = wrap_objid[wrapjntid] + prm = wrap_prm[wrapjntid] # add to length - L = prm * qpos_in[worldid, jnt_qposadr[wrap_objid_]] - # TODO(team): compare atomic_add and for loop - wp.atomic_add(ten_length_out[worldid], tendon_jnt_adr_, L) + L = prm * qpos_in[worldid, jnt_qposadr[wrapobjid]] + wp.atomic_add(ten_length_out[worldid], tenid, L) # add to moment - ten_J_out[worldid, tendon_jnt_adr_, jnt_dofadr[wrap_objid_]] = prm + dofadr = jnt_dofadr[wrapobjid] + rowadr = ten_J_rowadr[tenid] + rownnz = ten_J_rownnz[tenid] + for k in range(rownnz): + if ten_J_colind[rowadr + k] == dofadr: + ten_J_out[worldid, rowadr + k] = prm + break + + +@wp.func +def _accumulate_jac_chain( + # Model: + body_parentid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + # Data in: + cdof_in: wp.array2d(dtype=wp.spatial_vector), + # In: + offset: wp.vec3, + vec: wp.vec3, + bodyid: int, + rowadr: int, + rownnz: int, + scale: float, + worldid: int, + # Data out: + ten_J_out: wp.array2d(dtype=float), +): + """Walk body chain from bodyid to root, accumulate Jacobian contributions.""" + ptr = rownnz - 1 + bid = bodyid + while bid > 0: + bdofadr = body_dofadr[bid] + bdofnum = body_dofnum[bid] + # iterate DOFs in this body in descending order + for k_rev in range(bdofnum): + dof = bdofadr + bdofnum - 1 - k_rev + # scan pointer backward to find matching colind entry + while ptr >= 0: + sparseid = rowadr + ptr + if ten_J_colind[sparseid] <= dof: + break + ptr -= 1 + if ptr >= 0 and ten_J_colind[sparseid] == dof: + cdof = cdof_in[worldid, dof] + cdof_ang = wp.spatial_top(cdof) + cdof_lin = wp.spatial_bottom(cdof) + jacp = cdof_lin + wp.cross(cdof_ang, offset) + J = wp.dot(jacp, vec) * scale + if J != 0.0: + wp.atomic_add(ten_J_out[worldid], sparseid, J) + bid = body_parentid[bid] @wp.kernel def _spatial_site_tendon( # Model: - nv: int, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), tendon_site_pair_adr: wp.array(dtype=int), wrap_site_pair_adr: wp.array(dtype=int), @@ -3049,14 +3189,14 @@ def _spatial_site_tendon( subtree_com_in: wp.array2d(dtype=wp.vec3), cdof_in: wp.array2d(dtype=wp.spatial_vector), # Data out: - ten_J_out: wp.array3d(dtype=float), + ten_J_out: wp.array2d(dtype=float), ten_length_out: wp.array2d(dtype=float), ): worldid, elementid = wp.tid() # site pairs site_pair_adr = wrap_site_pair_adr[elementid] - ten_adr = tendon_site_pair_adr[elementid] + tenid = tendon_site_pair_adr[elementid] # pulley scaling pulley_scale = wrap_pulley_scale[site_pair_adr] @@ -3068,7 +3208,7 @@ def _spatial_site_tendon( pnt1 = site_xpos_in[worldid, id1] dif = pnt1 - pnt0 vec, length = math.normalize_with_norm(dif) - wp.atomic_add(ten_length_out[worldid], ten_adr, length * pulley_scale) + wp.atomic_add(ten_length_out[worldid], tenid, length * pulley_scale) if length < MJ_MINVAL: vec = wp.vec3(1.0, 0.0, 0.0) @@ -3076,26 +3216,55 @@ def _spatial_site_tendon( body0 = site_bodyid[id0] body1 = site_bodyid[id1] if body0 != body1: - # TODO(team): parallelize - for i in range(nv): - jacp1, _ = support.jac_dof(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pnt0, body0, i, worldid) - jacp2, _ = support.jac_dof(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pnt1, body1, i, worldid) - - J = wp.dot(jacp2 - jacp1, vec) - if J: - wp.atomic_add(ten_J_out[worldid, ten_adr], i, J * pulley_scale) + rownnz = ten_J_rownnz[tenid] + rowadr = ten_J_rowadr[tenid] + offset0 = pnt0 - subtree_com_in[worldid, body_rootid[body0]] + offset1 = pnt1 - subtree_com_in[worldid, body_rootid[body1]] + _accumulate_jac_chain( + body_parentid, + body_dofnum, + body_dofadr, + ten_J_colind, + cdof_in, + offset0, + vec, + body0, + rowadr, + rownnz, + -pulley_scale, + worldid, + ten_J_out, + ) + _accumulate_jac_chain( + body_parentid, + body_dofnum, + body_dofadr, + ten_J_colind, + cdof_in, + offset1, + vec, + body1, + rowadr, + rownnz, + pulley_scale, + worldid, + ten_J_out, + ) @wp.kernel def _spatial_geom_tendon( # Model: - nv: int, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), site_bodyid: wp.array(dtype=int), + ten_J_rownnz: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), wrap_type: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), wrap_prm: wp.array(dtype=float), @@ -3109,14 +3278,14 @@ def _spatial_geom_tendon( subtree_com_in: wp.array2d(dtype=wp.vec3), cdof_in: wp.array2d(dtype=wp.spatial_vector), # Data out: - ten_J_out: wp.array3d(dtype=float), + ten_J_out: wp.array2d(dtype=float), ten_length_out: wp.array2d(dtype=float), # Out: wrap_geom_xpos_out: wp.array2d(dtype=wp.spatial_vector), ): worldid, elementid = wp.tid() wrap_adr = wrap_geom_adr[elementid] - ten_adr = tendon_geom_adr[elementid] + tenid = tendon_geom_adr[elementid] # pulley scaling pulley_scale = wrap_pulley_scale[wrap_adr] @@ -3154,6 +3323,9 @@ def _spatial_geom_tendon( # store geom points wrap_geom_xpos_out[worldid, elementid] = wp.spatial_vector(geom_pnt0, geom_pnt1) + rownnz = ten_J_rownnz[tenid] + rowadr = ten_J_rowadr[tenid] + if length_geomgeom >= 0.0: dif_sitegeom = geom_pnt0 - site_pnt0 dif_geomsite = site_pnt1 - geom_pnt1 @@ -3164,7 +3336,7 @@ def _spatial_geom_tendon( length_sitegeomsite = length_sitegeom + length_geomgeom + length_geomsite if length_sitegeomsite: - wp.atomic_add(ten_length_out[worldid], ten_adr, length_sitegeomsite * pulley_scale) + wp.atomic_add(ten_length_out[worldid], tenid, length_sitegeomsite * pulley_scale) # moment if length_sitegeom < MJ_MINVAL: @@ -3176,61 +3348,120 @@ def _spatial_geom_tendon( dif_body_sitegeom = bodyid_site0 != bodyid_geom dif_body_geomsite = bodyid_geom != bodyid_site1 - # TODO(team): parallelize - for i in range(nv): - J = float(0.0) - # site-geom - if dif_body_sitegeom: - jacp_site0, _ = support.jac_dof( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_pnt0, bodyid_site0, i, worldid - ) - - jacp_geom0, _ = support.jac_dof( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, geom_pnt0, bodyid_geom, i, worldid - ) - - J += wp.dot(jacp_geom0 - jacp_site0, vec_sitegeom) - - # geom-site - if dif_body_geomsite: - jacp_geom1, _ = support.jac_dof( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, geom_pnt1, bodyid_geom, i, worldid - ) - - jacp_site1, _ = support.jac_dof( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_pnt1, bodyid_site1, i, worldid - ) - - J += wp.dot(jacp_site1 - jacp_geom1, vec_geomsite) + # site-geom segment + if dif_body_sitegeom: + offset_site0 = site_pnt0 - subtree_com_in[worldid, body_rootid[bodyid_site0]] + offset_geom0 = geom_pnt0 - subtree_com_in[worldid, body_rootid[bodyid_geom]] + _accumulate_jac_chain( + body_parentid, + body_dofnum, + body_dofadr, + ten_J_colind, + cdof_in, + offset_site0, + vec_sitegeom, + bodyid_site0, + rowadr, + rownnz, + -pulley_scale, + worldid, + ten_J_out, + ) + _accumulate_jac_chain( + body_parentid, + body_dofnum, + body_dofadr, + ten_J_colind, + cdof_in, + offset_geom0, + vec_sitegeom, + bodyid_geom, + rowadr, + rownnz, + pulley_scale, + worldid, + ten_J_out, + ) - if J: - wp.atomic_add(ten_J_out[worldid, ten_adr], i, J * pulley_scale) + # geom-site segment + if dif_body_geomsite: + offset_geom1 = geom_pnt1 - subtree_com_in[worldid, body_rootid[bodyid_geom]] + offset_site1 = site_pnt1 - subtree_com_in[worldid, body_rootid[bodyid_site1]] + _accumulate_jac_chain( + body_parentid, + body_dofnum, + body_dofadr, + ten_J_colind, + cdof_in, + offset_geom1, + vec_geomsite, + bodyid_geom, + rowadr, + rownnz, + -pulley_scale, + worldid, + ten_J_out, + ) + _accumulate_jac_chain( + body_parentid, + body_dofnum, + body_dofadr, + ten_J_colind, + cdof_in, + offset_site1, + vec_geomsite, + bodyid_site1, + rowadr, + rownnz, + pulley_scale, + worldid, + ten_J_out, + ) else: dif_sitesite = site_pnt1 - site_pnt0 vec_sitesite, length_sitesite = math.normalize_with_norm(dif_sitesite) # length if length_sitesite: - wp.atomic_add(ten_length_out[worldid], ten_adr, length_sitesite * pulley_scale) + wp.atomic_add(ten_length_out[worldid], tenid, length_sitesite * pulley_scale) # moment if length_sitesite < MJ_MINVAL: vec_sitesite = wp.vec3(1.0, 0.0, 0.0) if bodyid_site0 != bodyid_site1: - # TODO(team): parallelize - for i in range(nv): - jacp1, _ = support.jac_dof( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_pnt0, bodyid_site0, i, worldid - ) - jacp2, _ = support.jac_dof( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_pnt1, bodyid_site1, i, worldid - ) - - J = wp.dot(jacp2 - jacp1, vec_sitesite) - - if J: - wp.atomic_add(ten_J_out[worldid, ten_adr], i, J * pulley_scale) + offset_site0 = site_pnt0 - subtree_com_in[worldid, body_rootid[bodyid_site0]] + offset_site1 = site_pnt1 - subtree_com_in[worldid, body_rootid[bodyid_site1]] + _accumulate_jac_chain( + body_parentid, + body_dofnum, + body_dofadr, + ten_J_colind, + cdof_in, + offset_site0, + vec_sitesite, + bodyid_site0, + rowadr, + rownnz, + -pulley_scale, + worldid, + ten_J_out, + ) + _accumulate_jac_chain( + body_parentid, + body_dofnum, + body_dofadr, + ten_J_colind, + cdof_in, + offset_site1, + vec_sitesite, + bodyid_site1, + rowadr, + rownnz, + pulley_scale, + worldid, + ten_J_out, + ) @wp.kernel @@ -3412,7 +3643,18 @@ def tendon(m: Model, d: Data): wp.launch( _joint_tendon, dim=(d.nworld, m.wrap_jnt_adr.size), - inputs=[m.jnt_qposadr, m.jnt_dofadr, m.wrap_objid, m.wrap_prm, m.tendon_jnt_adr, m.wrap_jnt_adr, d.qpos], + inputs=[ + m.jnt_qposadr, + m.jnt_dofadr, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, + m.wrap_objid, + m.wrap_prm, + m.tendon_jnt_adr, + m.wrap_jnt_adr, + d.qpos, + ], outputs=[d.ten_J, d.ten_length], ) @@ -3428,11 +3670,14 @@ def tendon(m: Model, d: Data): _spatial_site_tendon, dim=(d.nworld, m.wrap_site_pair_adr.size), inputs=[ - m.nv, m.body_parentid, m.body_rootid, - m.dof_bodyid, + m.body_dofnum, + m.body_dofadr, m.site_bodyid, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, m.wrap_objid, m.tendon_site_pair_adr, m.wrap_site_pair_adr, @@ -3449,13 +3694,16 @@ def tendon(m: Model, d: Data): _spatial_geom_tendon, dim=(d.nworld, m.wrap_geom_adr.size), inputs=[ - m.nv, m.body_parentid, m.body_rootid, - m.dof_bodyid, + m.body_dofnum, + m.body_dofadr, m.geom_bodyid, m.geom_size, m.site_bodyid, + m.ten_J_rownnz, + m.ten_J_rowadr, + m.ten_J_colind, m.wrap_type, m.wrap_objid, m.wrap_prm, diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/solver.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/solver.py index 82ac23c7c0..2fabebe0ad 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/solver.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/solver.py @@ -17,17 +17,17 @@ from math import ceil from math import sqrt +import warp as wp + from mujoco.mjx.third_party.mujoco_warp._src import math from mujoco.mjx.third_party.mujoco_warp._src import smooth from mujoco.mjx.third_party.mujoco_warp._src import support from mujoco.mjx.third_party.mujoco_warp._src import types from mujoco.mjx.third_party.mujoco_warp._src.block_cholesky import create_blocked_cholesky_func from mujoco.mjx.third_party.mujoco_warp._src.block_cholesky import create_blocked_cholesky_solve_func -from mujoco.mjx.third_party.mujoco_warp._src.types import SPARSE_CONSTRAINT_JACOBIAN from mujoco.mjx.third_party.mujoco_warp._src.warp_util import cache_kernel from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope from mujoco.mjx.third_party.mujoco_warp._src.warp_util import scoped_mathdx_gemm_disabled -import warp as wp wp.set_module_options({"enable_backward": False}) @@ -91,14 +91,14 @@ def create_inverse_context(m: types.Model, d: types.Data) -> InverseContext: njmax = d.njmax return InverseContext( - Jaref=wp.empty((nworld, njmax), dtype=float), - search_dot=wp.empty((nworld,), dtype=float), - gauss=wp.empty((nworld,), dtype=float), - cost=wp.empty((nworld,), dtype=float), - prev_cost=wp.empty((nworld,), dtype=float), - done=wp.empty((nworld,), dtype=bool), - changed_efc_ids=wp.empty((nworld, 0), dtype=int), - changed_efc_count=wp.empty((0,), dtype=int), + Jaref=wp.empty((nworld, njmax), dtype=float), + search_dot=wp.empty((nworld,), dtype=float), + gauss=wp.empty((nworld,), dtype=float), + cost=wp.empty((nworld,), dtype=float), + prev_cost=wp.empty((nworld,), dtype=float), + done=wp.empty((nworld,), dtype=bool), + changed_efc_ids=wp.empty((nworld, 0), dtype=int), + changed_efc_count=wp.empty((0,), dtype=int), ) @@ -121,36 +121,28 @@ def create_solver_context(m: types.Model, d: types.Data) -> SolverContext: alloc_hfactor = alloc_h and nv > _BLOCK_CHOLESKY_DIM return SolverContext( - Jaref=wp.empty((nworld, njmax), dtype=float), - search_dot=wp.empty((nworld,), dtype=float), - gauss=wp.empty((nworld,), dtype=float), - cost=wp.empty((nworld,), dtype=float), - prev_cost=wp.empty((nworld,), dtype=float), - done=wp.empty((nworld,), dtype=bool), - grad=wp.zeros((nworld, nv_pad), dtype=float), - grad_dot=wp.empty((nworld,), dtype=float), - Mgrad=wp.zeros((nworld, nv_pad), dtype=float), - search=wp.empty((nworld, nv), dtype=float), - mv=wp.empty((nworld, nv), dtype=float), - jv=wp.empty((nworld, njmax), dtype=float), - quad=wp.empty((nworld, njmax), dtype=wp.vec3), - quad_gauss=wp.empty((nworld,), dtype=wp.vec3), - alpha=wp.empty((nworld,), dtype=float), - prev_grad=wp.empty((nworld, nv), dtype=float), - prev_Mgrad=wp.empty((nworld, nv), dtype=float), - beta=wp.empty((nworld,), dtype=float), - h=wp.zeros((nworld, nv_pad, nv_pad), dtype=float) - if alloc_h - else wp.empty((nworld, 0, 0), dtype=float), - hfactor=wp.zeros((nworld, nv_pad, nv_pad), dtype=float) - if alloc_hfactor - else wp.empty((nworld, 0, 0), dtype=float), - changed_efc_ids=wp.empty((nworld, njmax), dtype=int) - if alloc_h - else wp.empty((nworld, 0), dtype=int), - changed_efc_count=wp.empty((nworld,), dtype=int) - if alloc_h - else wp.empty((0,), dtype=int), + Jaref=wp.empty((nworld, njmax), dtype=float), + search_dot=wp.empty((nworld,), dtype=float), + gauss=wp.empty((nworld,), dtype=float), + cost=wp.empty((nworld,), dtype=float), + prev_cost=wp.empty((nworld,), dtype=float), + done=wp.empty((nworld,), dtype=bool), + grad=wp.zeros((nworld, nv_pad), dtype=float), + grad_dot=wp.empty((nworld,), dtype=float), + Mgrad=wp.zeros((nworld, nv_pad), dtype=float), + search=wp.empty((nworld, nv), dtype=float), + mv=wp.empty((nworld, nv), dtype=float), + jv=wp.empty((nworld, njmax), dtype=float), + quad=wp.empty((nworld, njmax), dtype=wp.vec3), + quad_gauss=wp.empty((nworld,), dtype=wp.vec3), + alpha=wp.empty((nworld,), dtype=float), + prev_grad=wp.empty((nworld, nv), dtype=float), + prev_Mgrad=wp.empty((nworld, nv), dtype=float), + beta=wp.empty((nworld,), dtype=float), + h=wp.zeros((nworld, nv_pad, nv_pad), dtype=float) if alloc_h else wp.empty((nworld, 0, 0), dtype=float), + hfactor=wp.zeros((nworld, nv_pad, nv_pad), dtype=float) if alloc_hfactor else wp.empty((nworld, 0, 0), dtype=float), + changed_efc_ids=wp.empty((nworld, njmax), dtype=int) if alloc_h else wp.empty((nworld, 0), dtype=int), + changed_efc_count=wp.empty((nworld,), dtype=int) if alloc_h else wp.empty((0,), dtype=int), ) @@ -892,18 +884,19 @@ def _compute_efc_eval_pt_3alphas_elliptic( @cache_kernel -def linesearch_iterative(ls_iterations: int, cone_type: types.ConeType, fuse_jv: bool): +def linesearch_iterative(ls_iterations: int, cone_type: types.ConeType, fuse_jv: bool, is_sparse: bool): """Factory for iterative linesearch kernel. Args: - block_dim: Number of threads per block for tile reductions. ls_iterations: Max linesearch iterations (compile-time constant for loop optimization). cone_type: Friction cone type (PYRAMIDAL or ELLIPTIC) for compile-time optimization. fuse_jv: Whether to compute jv = J @ search in-kernel (efficient for small nv). + is_sparse: Use sparse matrix representation for constraint Jacobian. """ LS_ITERATIONS = ls_iterations IS_ELLIPTIC = cone_type == types.ConeType.ELLIPTIC FUSE_JV = fuse_jv + IS_SPARSE = is_sparse # Native snippet for CUDA __syncthreads() @wp.func_native(snippet="WP_TILE_SYNC();") @@ -922,46 +915,46 @@ def _syncthreads(): @wp.kernel(module="unique", enable_backward=False) def kernel( - # Model: - nv: int, - opt_tolerance: wp.array(dtype=float), - opt_ls_tolerance: wp.array(dtype=float), - opt_impratio_invsqrt: wp.array(dtype=float), - stat_meaninertia: wp.array(dtype=float), - # Data in: - ne_in: wp.array(dtype=int), - nf_in: wp.array(dtype=int), - nefc_in: wp.array(dtype=int), - qfrc_smooth_in: wp.array2d(dtype=float), - contact_friction_in: wp.array(dtype=types.vec5), - contact_dim_in: wp.array(dtype=int), - contact_efc_address_in: wp.array2d(dtype=int), - efc_type_in: wp.array2d(dtype=int), - efc_id_in: wp.array2d(dtype=int), - efc_J_rownnz_in: wp.array2d(dtype=int), - efc_J_rowadr_in: wp.array2d(dtype=int), - efc_J_colind_in: wp.array3d(dtype=int), - efc_J_in: wp.array3d(dtype=float), - efc_D_in: wp.array2d(dtype=float), - efc_frictionloss_in: wp.array2d(dtype=float), - njmax_in: int, - nacon_in: wp.array(dtype=int), - # In: - ctx_Jaref_in: wp.array2d(dtype=float), - ctx_search_in: wp.array2d(dtype=float), - ctx_search_dot_in: wp.array(dtype=float), - ctx_gauss_in: wp.array(dtype=float), - ctx_mv_in: wp.array2d(dtype=float), - ctx_jv_in: wp.array2d(dtype=float), - ctx_quad_in: wp.array2d(dtype=wp.vec3), - ctx_done_in: wp.array(dtype=bool), - # Data out: - qacc_out: wp.array2d(dtype=float), - efc_Ma_out: wp.array2d(dtype=float), - # Out: - ctx_Jaref_out: wp.array2d(dtype=float), - ctx_jv_out: wp.array2d(dtype=float), - ctx_quad_out: wp.array2d(dtype=wp.vec3), + # Model: + nv: int, + opt_tolerance: wp.array(dtype=float), + opt_ls_tolerance: wp.array(dtype=float), + opt_impratio_invsqrt: wp.array(dtype=float), + stat_meaninertia: wp.array(dtype=float), + # Data in: + ne_in: wp.array(dtype=int), + nf_in: wp.array(dtype=int), + nefc_in: wp.array(dtype=int), + qfrc_smooth_in: wp.array2d(dtype=float), + contact_friction_in: wp.array(dtype=types.vec5), + contact_dim_in: wp.array(dtype=int), + contact_efc_address_in: wp.array2d(dtype=int), + efc_type_in: wp.array2d(dtype=int), + efc_id_in: wp.array2d(dtype=int), + efc_J_rownnz_in: wp.array2d(dtype=int), + efc_J_rowadr_in: wp.array2d(dtype=int), + efc_J_colind_in: wp.array3d(dtype=int), + efc_J_in: wp.array3d(dtype=float), + efc_D_in: wp.array2d(dtype=float), + efc_frictionloss_in: wp.array2d(dtype=float), + njmax_in: int, + nacon_in: wp.array(dtype=int), + # In: + ctx_Jaref_in: wp.array2d(dtype=float), + ctx_search_in: wp.array2d(dtype=float), + ctx_search_dot_in: wp.array(dtype=float), + ctx_gauss_in: wp.array(dtype=float), + ctx_mv_in: wp.array2d(dtype=float), + ctx_jv_in: wp.array2d(dtype=float), + ctx_quad_in: wp.array2d(dtype=wp.vec3), + ctx_done_in: wp.array(dtype=bool), + # Data out: + qacc_out: wp.array2d(dtype=float), + efc_Ma_out: wp.array2d(dtype=float), + # Out: + ctx_Jaref_out: wp.array2d(dtype=float), + ctx_jv_out: wp.array2d(dtype=float), + ctx_quad_out: wp.array2d(dtype=wp.vec3), ): worldid, tid = wp.tid() @@ -976,15 +969,13 @@ def kernel( if wp.static(FUSE_JV): for efcid in range(tid, nefc, wp.block_dim()): jv = float(0.0) - if wp.static(SPARSE_CONSTRAINT_JACOBIAN): + if wp.static(IS_SPARSE): rownnz = efc_J_rownnz_in[worldid, efcid] rowadr = efc_J_rowadr_in[worldid, efcid] for k in range(rownnz): sparseid = rowadr + k colind = efc_J_colind_in[worldid, 0, sparseid] - jv += ( - efc_J_in[worldid, 0, sparseid] * ctx_search_in[worldid, colind] - ) + jv += efc_J_in[worldid, 0, sparseid] * ctx_search_in[worldid, colind] else: for i in range(nv): jv += efc_J_in[worldid, efcid, i] * ctx_search_in[worldid, i] @@ -1360,42 +1351,42 @@ def _linesearch_iterative(m: types.Model, d: types.Data, ctx: SolverContext, fus fuse_jv: Whether jv is computed in-kernel (True) or pre-computed (False). """ wp.launch_tiled( - linesearch_iterative(m.opt.ls_iterations, m.opt.cone, fuse_jv), - dim=d.nworld, - inputs=[ - m.nv, - m.opt.tolerance, - m.opt.ls_tolerance, - m.opt.impratio_invsqrt, - m.stat.meaninertia, - d.ne, - d.nf, - d.nefc, - d.qfrc_smooth, - d.contact.friction, - d.contact.dim, - d.contact.efc_address, - d.efc.type, - d.efc.id, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.D, - d.efc.frictionloss, - d.njmax, - d.nacon, - ctx.Jaref, - ctx.search, - ctx.search_dot, - ctx.gauss, - ctx.mv, - ctx.jv, - ctx.quad, - ctx.done, - ], - outputs=[d.qacc, d.efc.Ma, ctx.Jaref, ctx.jv, ctx.quad], - block_dim=m.block_dim.linesearch_iterative, + linesearch_iterative(m.opt.ls_iterations, m.opt.cone, fuse_jv, m.is_sparse), + dim=d.nworld, + inputs=[ + m.nv, + m.opt.tolerance, + m.opt.ls_tolerance, + m.opt.impratio_invsqrt, + m.stat.meaninertia, + d.ne, + d.nf, + d.nefc, + d.qfrc_smooth, + d.contact.friction, + d.contact.dim, + d.contact.efc_address, + d.efc.type, + d.efc.id, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.D, + d.efc.frictionloss, + d.njmax, + d.nacon, + ctx.Jaref, + ctx.search, + ctx.search_dot, + ctx.gauss, + ctx.mv, + ctx.jv, + ctx.quad, + ctx.done, + ], + outputs=[d.qacc, d.efc.Ma, ctx.Jaref, ctx.jv, ctx.quad], + block_dim=m.block_dim.linesearch_iterative, ) @@ -1420,20 +1411,20 @@ def linesearch_zero_jv( @cache_kernel -def linesearch_jv_fused(opt_is_sparse: bool, nv: int, dofs_per_thread: int): +def linesearch_jv_fused(is_sparse: bool, nv: int, dofs_per_thread: int): @wp.kernel(module="unique", enable_backward=False) def kernel( - # Data in: - nefc_in: wp.array(dtype=int), - efc_J_rownnz_in: wp.array2d(dtype=int), - efc_J_rowadr_in: wp.array2d(dtype=int), - efc_J_colind_in: wp.array3d(dtype=int), - efc_J_in: wp.array3d(dtype=float), - # In: - ctx_search_in: wp.array2d(dtype=float), - ctx_done_in: wp.array(dtype=bool), - # Out: - ctx_jv_out: wp.array2d(dtype=float), + # Data in: + nefc_in: wp.array(dtype=int), + efc_J_rownnz_in: wp.array2d(dtype=int), + efc_J_rowadr_in: wp.array2d(dtype=int), + efc_J_colind_in: wp.array3d(dtype=int), + efc_J_in: wp.array3d(dtype=float), + # In: + ctx_search_in: wp.array2d(dtype=float), + ctx_done_in: wp.array(dtype=bool), + # Out: + ctx_jv_out: wp.array2d(dtype=float), ): worldid, efcid, dofstart = wp.tid() @@ -1446,23 +1437,21 @@ def kernel( jv_out = float(0.0) if wp.static(dofs_per_thread >= nv): - if wp.static(SPARSE_CONSTRAINT_JACOBIAN): + if wp.static(is_sparse): # Sparse: iterate over non-zero entries in the row rownnz = efc_J_rownnz_in[worldid, efcid] rowadr = efc_J_rowadr_in[worldid, efcid] for k in range(rownnz): sparseid = rowadr + k colind = efc_J_colind_in[worldid, 0, sparseid] - jv_out += ( - efc_J_in[worldid, 0, sparseid] * ctx_search_in[worldid, colind] - ) + jv_out += efc_J_in[worldid, 0, sparseid] * ctx_search_in[worldid, colind] else: for i in range(wp.static(min(dofs_per_thread, nv))): jv_out += efc_J_in[worldid, efcid, i] * ctx_search_in[worldid, i] ctx_jv_out[worldid, efcid] = jv_out else: - if wp.static(SPARSE_CONSTRAINT_JACOBIAN): + if wp.static(is_sparse): # Sparse: thread 0 handles entire row (sparse entries << nv typically) if dofstart == 0: rownnz = efc_J_rownnz_in[worldid, efcid] @@ -1470,9 +1459,7 @@ def kernel( for k in range(rownnz): sparseid = rowadr + k colind = efc_J_colind_in[worldid, 0, sparseid] - jv_out += ( - efc_J_in[worldid, 0, sparseid] * ctx_search_in[worldid, colind] - ) + jv_out += efc_J_in[worldid, 0, sparseid] * ctx_search_in[worldid, colind] ctx_jv_out[worldid, efcid] = jv_out else: for i in range(wp.static(dofs_per_thread)): @@ -1583,10 +1570,7 @@ def linesearch_prepare_quad( dim = contact_dim_in[conid] friction = contact_friction_in[conid] - mu = ( - friction[0] - * opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] - ) + mu = friction[0] * opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] u0 = Jaref * mu v0 = jv * mu @@ -1707,18 +1691,10 @@ def _linesearch(m: types.Model, d: types.Data, ctx: SolverContext, cost: wp.arra ) wp.launch( - linesearch_jv_fused(m.is_sparse, m.nv, dofs_per_thread), - dim=(d.nworld, d.njmax, threads_per_efc), - inputs=[ - d.nefc, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - ctx.search, - ctx.done, - ], - outputs=[ctx.jv], + linesearch_jv_fused(m.is_sparse, m.nv, dofs_per_thread), + dim=(d.nworld, d.njmax, threads_per_efc), + inputs=[d.nefc, d.efc.J_rownnz, d.efc.J_rowadr, d.efc.J_colind, d.efc.J, ctx.search, ctx.done], + outputs=[ctx.jv], ) if m.opt.ls_parallel: @@ -1744,19 +1720,19 @@ def solve_init_efc( @cache_kernel -def solve_init_jaref(opt_is_sparse: bool, nv: int, dofs_per_thread: int): +def solve_init_jaref(is_sparse: bool, nv: int, dofs_per_thread: int): @wp.kernel(module="unique", enable_backward=False) def kernel( - # Data in: - nefc_in: wp.array(dtype=int), - qacc_in: wp.array2d(dtype=float), - efc_J_rownnz_in: wp.array2d(dtype=int), - efc_J_rowadr_in: wp.array2d(dtype=int), - efc_J_colind_in: wp.array3d(dtype=int), - efc_J_in: wp.array3d(dtype=float), - efc_aref_in: wp.array2d(dtype=float), - # Out: - ctx_Jaref_out: wp.array2d(dtype=float), + # Data in: + nefc_in: wp.array(dtype=int), + qacc_in: wp.array2d(dtype=float), + efc_J_rownnz_in: wp.array2d(dtype=int), + efc_J_rowadr_in: wp.array2d(dtype=int), + efc_J_colind_in: wp.array3d(dtype=int), + efc_J_in: wp.array3d(dtype=float), + efc_aref_in: wp.array2d(dtype=float), + # Out: + ctx_Jaref_out: wp.array2d(dtype=float), ): worldid, efcid, dofstart = wp.tid() @@ -1764,7 +1740,7 @@ def kernel( return jaref = float(0.0) - if wp.static(SPARSE_CONSTRAINT_JACOBIAN): + if wp.static(is_sparse): rownnz = efc_J_rownnz_in[worldid, efcid] rowadr = efc_J_rowadr_in[worldid, efcid] for i in range(rownnz): @@ -1785,9 +1761,7 @@ def kernel( jaref += efc_J_in[worldid, efcid, ii] * qacc_in[worldid, ii] if dofstart == 0: - wp.atomic_add( - ctx_Jaref_out, worldid, efcid, jaref - efc_aref_in[worldid, efcid] - ) + wp.atomic_add(ctx_Jaref_out, worldid, efcid, jaref - efc_aref_in[worldid, efcid]) else: wp.atomic_add(ctx_Jaref_out, worldid, efcid, jaref) @@ -1834,30 +1808,30 @@ def update_constraint_efc(track_changes: bool): @wp.kernel(module="unique", enable_backward=False) def kernel( - # Model: - opt_impratio_invsqrt: wp.array(dtype=float), - # Data in: - ne_in: wp.array(dtype=int), - nf_in: wp.array(dtype=int), - nefc_in: wp.array(dtype=int), - contact_friction_in: wp.array(dtype=types.vec5), - contact_dim_in: wp.array(dtype=int), - contact_efc_address_in: wp.array2d(dtype=int), - efc_type_in: wp.array2d(dtype=int), - efc_id_in: wp.array2d(dtype=int), - efc_D_in: wp.array2d(dtype=float), - efc_frictionloss_in: wp.array2d(dtype=float), - nacon_in: wp.array(dtype=int), - # In: - ctx_Jaref_in: wp.array2d(dtype=float), - ctx_done_in: wp.array(dtype=bool), - # Data out: - efc_force_out: wp.array2d(dtype=float), - efc_state_out: wp.array2d(dtype=int), - # Out: - ctx_cost_out: wp.array(dtype=float), - changed_ids_out: wp.array2d(dtype=int), - changed_count_out: wp.array(dtype=int), + # Model: + opt_impratio_invsqrt: wp.array(dtype=float), + # Data in: + ne_in: wp.array(dtype=int), + nf_in: wp.array(dtype=int), + nefc_in: wp.array(dtype=int), + contact_friction_in: wp.array(dtype=types.vec5), + contact_dim_in: wp.array(dtype=int), + contact_efc_address_in: wp.array2d(dtype=int), + efc_type_in: wp.array2d(dtype=int), + efc_id_in: wp.array2d(dtype=int), + efc_D_in: wp.array2d(dtype=float), + efc_frictionloss_in: wp.array2d(dtype=float), + nacon_in: wp.array(dtype=int), + # In: + ctx_Jaref_in: wp.array2d(dtype=float), + ctx_done_in: wp.array(dtype=bool), + # Data out: + efc_force_out: wp.array2d(dtype=float), + efc_state_out: wp.array2d(dtype=int), + # Out: + ctx_cost_out: wp.array(dtype=float), + changed_ids_out: wp.array2d(dtype=int), + changed_count_out: wp.array(dtype=int), ): worldid, efcid = wp.tid() @@ -1869,9 +1843,7 @@ def kernel( # Read old QUADRATIC status before overwriting if wp.static(TRACK_CHANGES): - old_quad = ( - efc_state_out[worldid, efcid] == types.ConstraintState.QUADRATIC.value - ) + old_quad = efc_state_out[worldid, efcid] == types.ConstraintState.QUADRATIC.value efc_D = efc_D_in[worldid, efcid] Jaref = ctx_Jaref_in[worldid, efcid] @@ -1919,10 +1891,7 @@ def kernel( dim = contact_dim_in[conid] friction = contact_friction_in[conid] - mu = ( - friction[0] - * opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] - ) + mu = friction[0] * opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] efcid0 = contact_efc_address_in[conid, 0] if efcid0 < 0: @@ -1984,17 +1953,17 @@ def kernel( @wp.kernel def update_constraint_init_qfrc_constraint_sparse( - # Data in: - nefc_in: wp.array(dtype=int), - efc_J_rownnz_in: wp.array2d(dtype=int), - efc_J_rowadr_in: wp.array2d(dtype=int), - efc_J_colind_in: wp.array3d(dtype=int), - efc_J_in: wp.array3d(dtype=float), - efc_force_in: wp.array2d(dtype=float), - # In: - ctx_done_in: wp.array(dtype=bool), - # Data out: - qfrc_constraint_out: wp.array2d(dtype=float), + # Data in: + nefc_in: wp.array(dtype=int), + efc_J_rownnz_in: wp.array2d(dtype=int), + efc_J_rowadr_in: wp.array2d(dtype=int), + efc_J_colind_in: wp.array3d(dtype=int), + efc_J_in: wp.array3d(dtype=float), + efc_force_in: wp.array2d(dtype=float), + # In: + ctx_done_in: wp.array(dtype=bool), + # Data out: + qfrc_constraint_out: wp.array2d(dtype=float), ): worldid, efcid = wp.tid() @@ -2017,15 +1986,15 @@ def update_constraint_init_qfrc_constraint_sparse( @wp.kernel def update_constraint_init_qfrc_constraint_dense( - # Data in: - nefc_in: wp.array(dtype=int), - efc_J_in: wp.array3d(dtype=float), - efc_force_in: wp.array2d(dtype=float), - njmax_in: int, - # In: - ctx_done_in: wp.array(dtype=bool), - # Data out: - qfrc_constraint_out: wp.array2d(dtype=float), + # Data in: + nefc_in: wp.array(dtype=int), + efc_J_in: wp.array3d(dtype=float), + efc_force_in: wp.array2d(dtype=float), + njmax_in: int, + # In: + ctx_done_in: wp.array(dtype=bool), + # Data out: + qfrc_constraint_out: wp.array2d(dtype=float), ): worldid, dofid = wp.tid() @@ -2076,23 +2045,23 @@ def kernel( gauss_cost += (efc_Ma_in[worldid, ii] - qfrc_smooth_in[worldid, ii]) * ( qacc_in[worldid, ii] - qacc_smooth_in[worldid, ii] ) - wp.atomic_add(ctx_gauss_out, worldid, gauss_cost) - wp.atomic_add(ctx_cost_out, worldid, gauss_cost) + wp.atomic_add(ctx_gauss_out, worldid, 0.5 * gauss_cost) + wp.atomic_add(ctx_cost_out, worldid, 0.5 * gauss_cost) return kernel @wp.kernel def update_gradient_h_incremental( - # Data in: - efc_J_in: wp.array3d(dtype=float), - efc_D_in: wp.array2d(dtype=float), - efc_state_in: wp.array2d(dtype=int), - # In: - changed_ids_in: wp.array2d(dtype=int), - changed_count_in: wp.array(dtype=int), - # Out: - ctx_h_out: wp.array3d(dtype=float), + # Data in: + efc_J_in: wp.array3d(dtype=float), + efc_D_in: wp.array2d(dtype=float), + efc_state_in: wp.array2d(dtype=int), + # In: + changed_ids_in: wp.array2d(dtype=int), + changed_count_in: wp.array(dtype=int), + # Out: + ctx_h_out: wp.array3d(dtype=float), ): """Incrementally update lower triangle of H for changed constraints. @@ -2131,18 +2100,18 @@ def update_gradient_h_incremental( @wp.kernel def update_gradient_h_incremental_sparse( - # Data in: - efc_J_rownnz_in: wp.array2d(dtype=int), - efc_J_rowadr_in: wp.array2d(dtype=int), - efc_J_colind_in: wp.array3d(dtype=int), - efc_J_in: wp.array3d(dtype=float), - efc_D_in: wp.array2d(dtype=float), - efc_state_in: wp.array2d(dtype=int), - # In: - changed_ids_in: wp.array2d(dtype=int), - changed_count_in: wp.array(dtype=int), - # Out: - ctx_h_out: wp.array3d(dtype=float), + # Data in: + efc_J_rownnz_in: wp.array2d(dtype=int), + efc_J_rowadr_in: wp.array2d(dtype=int), + efc_J_colind_in: wp.array3d(dtype=int), + efc_J_in: wp.array3d(dtype=float), + efc_D_in: wp.array2d(dtype=float), + efc_state_in: wp.array2d(dtype=int), + # In: + changed_ids_in: wp.array2d(dtype=int), + changed_count_in: wp.array(dtype=int), + # Out: + ctx_h_out: wp.array3d(dtype=float), ): """Incrementally update lower triangle of H for changed constraints (sparse J).""" worldid, change_idx = wp.tid() @@ -2182,12 +2151,7 @@ def update_gradient_h_incremental_sparse( wp.atomic_add(ctx_h_out[worldid, colindj], colindi, h) -def _update_constraint( - m: types.Model, - d: types.Data, - ctx: SolverContext | InverseContext, - track_changes: bool = False, -): +def _update_constraint(m: types.Model, d: types.Data, ctx: SolverContext | InverseContext, track_changes: bool = False): """Update constraint arrays after each solve iteration.""" wp.launch( update_constraint_init_cost, @@ -2197,58 +2161,44 @@ def _update_constraint( ) efc_inputs = [ - m.opt.impratio_invsqrt, - d.ne, - d.nf, - d.nefc, - d.contact.friction, - d.contact.dim, - d.contact.efc_address, - d.efc.type, - d.efc.id, - d.efc.D, - d.efc.frictionloss, - d.nacon, - ctx.Jaref, - ctx.done, + m.opt.impratio_invsqrt, + d.ne, + d.nf, + d.nefc, + d.contact.friction, + d.contact.dim, + d.contact.efc_address, + d.efc.type, + d.efc.id, + d.efc.D, + d.efc.frictionloss, + d.nacon, + ctx.Jaref, + ctx.done, ] wp.launch( - update_constraint_efc(track_changes), - dim=(d.nworld, d.njmax), - inputs=efc_inputs, - outputs=[ - d.efc.force, - d.efc.state, - ctx.cost, - ctx.changed_efc_ids, - ctx.changed_efc_count, - ], + update_constraint_efc(track_changes), + dim=(d.nworld, d.njmax), + inputs=efc_inputs, + outputs=[d.efc.force, d.efc.state, ctx.cost, ctx.changed_efc_ids, ctx.changed_efc_count], ) # qfrc_constraint = efc_J.T @ efc_force - if SPARSE_CONSTRAINT_JACOBIAN: + if m.is_sparse: d.qfrc_constraint.zero_() wp.launch( - update_constraint_init_qfrc_constraint_sparse, - dim=(d.nworld, d.njmax), - inputs=[ - d.nefc, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.force, - ctx.done, - ], - outputs=[d.qfrc_constraint], + update_constraint_init_qfrc_constraint_sparse, + dim=(d.nworld, d.njmax), + inputs=[d.nefc, d.efc.J_rownnz, d.efc.J_rowadr, d.efc.J_colind, d.efc.J, d.efc.force, ctx.done], + outputs=[d.qfrc_constraint], ) else: wp.launch( - update_constraint_init_qfrc_constraint_dense, - dim=(d.nworld, m.nv), - inputs=[d.nefc, d.efc.J, d.efc.force, d.njmax, ctx.done], - outputs=[d.qfrc_constraint], + update_constraint_init_qfrc_constraint_dense, + dim=(d.nworld, m.nv), + inputs=[d.nefc, d.efc.J, d.efc.force, d.njmax, ctx.done], + outputs=[d.qfrc_constraint], ) # if we are only using 1 thread, it makes sense to do more dofs and skip the atomics. @@ -2441,9 +2391,7 @@ def kernel( nefc = nefc_in[worldid] - sum_val = wp.tile_load( - qM_in[worldid], shape=(nv_pad, nv_pad), bounds_check=True - ) + sum_val = wp.tile_load(qM_in[worldid], shape=(nv_pad, nv_pad), bounds_check=True) # Each tile processes one output tile by looping over all constraints for k in range(0, njmax, TILE_SIZE_K): @@ -2453,12 +2401,7 @@ def kernel( # AD: leaving bounds-check disabled here because I'm not entirely sure that # everything always hits the fast path. The padding takes care of any # potential OOB accesses. - J_kj = wp.tile_load( - efc_J_in[worldid], - shape=(TILE_SIZE_K, nv_pad), - offset=(k, 0), - bounds_check=False, - ) + J_kj = wp.tile_load(efc_J_in[worldid], shape=(TILE_SIZE_K, nv_pad), offset=(k, 0), bounds_check=False) # state check D_k = wp.tile_load(efc_D_in[worldid], shape=TILE_SIZE_K, offset=k, bounds_check=False) @@ -2473,11 +2416,7 @@ def kernel( active_tile = wp.tile_map(active_check, tid_tile, threshold_tile) D_k = wp.tile_map(wp.mul, active_tile, D_k) - J_ki = wp.tile_map( - wp.mul, - wp.tile_transpose(J_kj), - wp.tile_broadcast(D_k, shape=(nv_pad, TILE_SIZE_K)), - ) + J_ki = wp.tile_map(wp.mul, wp.tile_transpose(J_kj), wp.tile_broadcast(D_k, shape=(nv_pad, TILE_SIZE_K))) sum_val += wp.tile_matmul(J_ki, J_kj) @@ -2489,32 +2428,32 @@ def kernel( # TODO(thowell): combine with JTDAJ ? @wp.kernel def update_gradient_JTCJ_sparse( - # Model: - opt_impratio_invsqrt: wp.array(dtype=float), - dof_tri_row: wp.array(dtype=int), - dof_tri_col: wp.array(dtype=int), - # Data in: - contact_dist_in: wp.array(dtype=float), - contact_includemargin_in: wp.array(dtype=float), - contact_friction_in: wp.array(dtype=types.vec5), - contact_dim_in: wp.array(dtype=int), - contact_efc_address_in: wp.array2d(dtype=int), - contact_worldid_in: wp.array(dtype=int), - efc_J_rownnz_in: wp.array2d(dtype=int), - efc_J_rowadr_in: wp.array2d(dtype=int), - efc_J_colind_in: wp.array3d(dtype=int), - efc_J_in: wp.array3d(dtype=float), - efc_D_in: wp.array2d(dtype=float), - efc_state_in: wp.array2d(dtype=int), - naconmax_in: int, - nacon_in: wp.array(dtype=int), - # In: - ctx_Jaref_in: wp.array2d(dtype=float), - ctx_done_in: wp.array(dtype=bool), - nblocks_perblock: int, - dim_block: int, - # Out: - h_out: wp.array3d(dtype=float), + # Model: + opt_impratio_invsqrt: wp.array(dtype=float), + dof_tri_row: wp.array(dtype=int), + dof_tri_col: wp.array(dtype=int), + # Data in: + contact_dist_in: wp.array(dtype=float), + contact_includemargin_in: wp.array(dtype=float), + contact_friction_in: wp.array(dtype=types.vec5), + contact_dim_in: wp.array(dtype=int), + contact_efc_address_in: wp.array2d(dtype=int), + contact_worldid_in: wp.array(dtype=int), + efc_J_rownnz_in: wp.array2d(dtype=int), + efc_J_rowadr_in: wp.array2d(dtype=int), + efc_J_colind_in: wp.array3d(dtype=int), + efc_J_in: wp.array3d(dtype=float), + efc_D_in: wp.array2d(dtype=float), + efc_state_in: wp.array2d(dtype=int), + naconmax_in: int, + nacon_in: wp.array(dtype=int), + # In: + ctx_Jaref_in: wp.array2d(dtype=float), + ctx_done_in: wp.array(dtype=bool), + nblocks_perblock: int, + dim_block: int, + # Out: + ctx_h_out: wp.array3d(dtype=float), ): conid_start, elementid = wp.tid() @@ -2529,20 +2468,37 @@ def update_gradient_JTCJ_sparse( worldid = contact_worldid_in[conid] if ctx_done_in[worldid]: - return + continue condim = contact_dim_in[conid] if condim == 1: - return + continue # check contact status if contact_dist_in[conid] - contact_includemargin_in[conid] >= 0.0: - return + continue efcid0 = contact_efc_address_in[conid, 0] if efc_state_in[worldid, efcid0] != types.ConstraintState.CONE: - return + continue + + # All dims share the same sparsity pattern. Scan colind once to find + # the sparse positions of dof1id and dof2id. Skip if either is absent. + rownnz = efc_J_rownnz_in[worldid, efcid0] + rowadr0 = efc_J_rowadr_in[worldid, efcid0] + pos1 = int(-1) + pos2 = int(-1) + for k in range(rownnz): + col = efc_J_colind_in[worldid, 0, rowadr0 + k] + if col == dof1id: + pos1 = k + if col == dof2id: + pos2 = k + if pos1 >= 0 and pos2 >= 0: + break + if pos1 < 0 or pos2 < 0: + continue fri = contact_friction_in[conid] mu = fri[0] * opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] @@ -2551,7 +2507,7 @@ def update_gradient_JTCJ_sparse( dm = math.safe_div(efc_D_in[worldid, efcid0], mu2 * (1.0 + mu2)) if dm == 0.0: - return + continue n = ctx_Jaref_in[worldid, efcid0] * mu u = types.vec6(n, 0.0, 0.0, 0.0, 0.0, 0.0) @@ -2570,52 +2526,40 @@ def update_gradient_JTCJ_sparse( t = wp.max(t, types.MJ_MINVAL) ttt = wp.max(t * t * t, types.MJ_MINVAL) + # Precompute common subexpressions. + mu_over_t = math.safe_div(mu, t) + mu_n_over_ttt = mu * math.safe_div(n, ttt) + mu2_minus_mu_n_over_t = mu2 - mu * math.safe_div(n, t) + h = float(0.0) for dim1id in range(condim): if dim1id == 0: - efcid1 = efcid0 + rowadr1 = rowadr0 + dm_fri1 = dm * mu else: efcid1 = contact_efc_address_in[conid, dim1id] + rowadr1 = efc_J_rowadr_in[worldid, efcid1] + dm_fri1 = dm * fri[dim1id - 1] - # TODO(team): improve performance for sparse code path - rownnz1 = efc_J_rownnz_in[worldid, efcid1] - rowadr1 = efc_J_rowadr_in[worldid, efcid1] - - efc_J11 = float(0.0) - efc_J12 = float(0.0) - for i1 in range(rownnz1): - sparseid1 = rowadr1 + i1 - colind1 = efc_J_colind_in[worldid, 0, sparseid1] - if dof1id == colind1: - efc_J11 = efc_J_in[worldid, 0, sparseid1] - if dof2id == colind1: - efc_J12 = efc_J_in[worldid, 0, sparseid1] - if efc_J11 != 0.0 and efc_J12 != 0.0: - break + # Direct J reads using cached sparse positions. + efc_J11 = efc_J_in[worldid, 0, rowadr1 + pos1] + efc_J12 = efc_J_in[worldid, 0, rowadr1 + pos2] ui = u[dim1id] for dim2id in range(0, dim1id + 1): if dim2id == 0: - efcid2 = efcid0 + rowadr2 = rowadr0 + dm_fri12 = dm_fri1 * mu else: efcid2 = contact_efc_address_in[conid, dim2id] + rowadr2 = efc_J_rowadr_in[worldid, efcid2] + dm_fri12 = dm_fri1 * fri[dim2id - 1] - rownnz2 = efc_J_rownnz_in[worldid, efcid2] - rowadr2 = efc_J_rowadr_in[worldid, efcid2] - - efc_J21 = float(0.0) - efc_J22 = float(0.0) - for i2 in range(rownnz2): - sparseid2 = rowadr2 + i2 - colind2 = efc_J_colind_in[worldid, 0, sparseid2] - if dof1id == colind2: - efc_J21 = efc_J_in[worldid, 0, sparseid2] - if dof2id == colind2: - efc_J22 = efc_J_in[worldid, 0, sparseid2] - if efc_J21 != 0.0 and efc_J22 != 0.0: - break + # Direct J reads using cached sparse positions. + efc_J21 = efc_J_in[worldid, 0, rowadr2 + pos1] + efc_J22 = efc_J_in[worldid, 0, rowadr2 + pos2] uj = u[dim2id] @@ -2623,28 +2567,17 @@ def update_gradient_JTCJ_sparse( if dim1id == 0 and dim2id == 0: hcone = 1.0 elif dim1id == 0: - hcone = -math.safe_div(mu, t) * uj + hcone = -mu_over_t * uj elif dim2id == 0: - hcone = -math.safe_div(mu, t) * ui + hcone = -mu_over_t * ui else: - hcone = mu * math.safe_div(n, ttt) * ui * uj + hcone = mu_n_over_ttt * ui * uj # add to diagonal: mu^2 - mu * n / t if dim1id == dim2id: - hcone += mu2 - mu * math.safe_div(n, t) + hcone += mu2_minus_mu_n_over_t - # pre and post multiply by diag(mu, friction) scale by dm - if dim1id == 0: - fri1 = mu - else: - fri1 = fri[dim1id - 1] - - if dim2id == 0: - fri2 = mu - else: - fri2 = fri[dim2id - 1] - - hcone *= dm * fri1 * fri2 + hcone *= dm_fri12 if hcone != 0.0: h += hcone * efc_J11 * efc_J22 @@ -2652,34 +2585,34 @@ def update_gradient_JTCJ_sparse( if dim1id != dim2id: h += hcone * efc_J12 * efc_J21 - h_out[worldid, dof1id, dof2id] += h + ctx_h_out[worldid, dof1id, dof2id] += h @wp.kernel def update_gradient_JTCJ_dense( - # Model: - opt_impratio_invsqrt: wp.array(dtype=float), - dof_tri_row: wp.array(dtype=int), - dof_tri_col: wp.array(dtype=int), - # Data in: - contact_dist_in: wp.array(dtype=float), - contact_includemargin_in: wp.array(dtype=float), - contact_friction_in: wp.array(dtype=types.vec5), - contact_dim_in: wp.array(dtype=int), - contact_efc_address_in: wp.array2d(dtype=int), - contact_worldid_in: wp.array(dtype=int), - efc_J_in: wp.array3d(dtype=float), - efc_D_in: wp.array2d(dtype=float), - efc_state_in: wp.array2d(dtype=int), - naconmax_in: int, - nacon_in: wp.array(dtype=int), - # In: - ctx_Jaref_in: wp.array2d(dtype=float), - ctx_done_in: wp.array(dtype=bool), - nblocks_perblock: int, - dim_block: int, - # Out: - ctx_h_out: wp.array3d(dtype=float), + # Model: + opt_impratio_invsqrt: wp.array(dtype=float), + dof_tri_row: wp.array(dtype=int), + dof_tri_col: wp.array(dtype=int), + # Data in: + contact_dist_in: wp.array(dtype=float), + contact_includemargin_in: wp.array(dtype=float), + contact_friction_in: wp.array(dtype=types.vec5), + contact_dim_in: wp.array(dtype=int), + contact_efc_address_in: wp.array2d(dtype=int), + contact_worldid_in: wp.array(dtype=int), + efc_J_in: wp.array3d(dtype=float), + efc_D_in: wp.array2d(dtype=float), + efc_state_in: wp.array2d(dtype=int), + naconmax_in: int, + nacon_in: wp.array(dtype=int), + # In: + ctx_Jaref_in: wp.array2d(dtype=float), + ctx_done_in: wp.array(dtype=bool), + nblocks_perblock: int, + dim_block: int, + # Out: + ctx_h_out: wp.array3d(dtype=float), ): conid_start, elementid = wp.tid() @@ -2863,40 +2796,86 @@ def padding_h(nv: int, ctx_done_in: wp.array(dtype=bool), ctx_h_out: wp.array3d( ctx_h_out[worldid, dofid, dofid] = 1.0 -def _cholesky_factorize_solve( - m: types.Model, d: types.Data, ctx: SolverContext -): +def _cholesky_factorize_solve(m: types.Model, d: types.Data, ctx: SolverContext): """Cholesky factorize ctx.h and solve for Mgrad.""" if m.nv <= _BLOCK_CHOLESKY_DIM: wp.launch_tiled( - update_gradient_cholesky(m.nv), - dim=d.nworld, - inputs=[ctx.grad, ctx.h, ctx.done], - outputs=[ctx.Mgrad], - block_dim=m.block_dim.update_gradient_cholesky, + update_gradient_cholesky(m.nv), + dim=d.nworld, + inputs=[ctx.grad, ctx.h, ctx.done], + outputs=[ctx.Mgrad], + block_dim=m.block_dim.update_gradient_cholesky, ) else: wp.launch( - padding_h, - dim=(d.nworld, m.nv_pad - m.nv), - inputs=[m.nv, ctx.done], - outputs=[ctx.h], + padding_h, + dim=(d.nworld, m.nv_pad - m.nv), + inputs=[m.nv, ctx.done], + outputs=[ctx.h], ) wp.launch_tiled( - update_gradient_cholesky_blocked(types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad), - dim=d.nworld, - inputs=[ - ctx.done, - ctx.grad.reshape(shape=(d.nworld, ctx.grad.shape[1], 1)), - ctx.h, - ctx.hfactor, - ], - outputs=[ctx.Mgrad.reshape(shape=(d.nworld, ctx.Mgrad.shape[1], 1))], - block_dim=m.block_dim.update_gradient_cholesky_blocked, + update_gradient_cholesky_blocked(types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad), + dim=d.nworld, + inputs=[ctx.done, ctx.grad.reshape(shape=(d.nworld, ctx.grad.shape[1], 1)), ctx.h, ctx.hfactor], + outputs=[ctx.Mgrad.reshape(shape=(d.nworld, ctx.Mgrad.shape[1], 1))], + block_dim=m.block_dim.update_gradient_cholesky_blocked, ) +@wp.kernel +def _JTDAJ_sparse( + # Data in: + nefc_in: wp.array(dtype=int), + efc_J_rownnz_in: wp.array2d(dtype=int), + efc_J_rowadr_in: wp.array2d(dtype=int), + efc_J_colind_in: wp.array3d(dtype=int), + efc_J_in: wp.array3d(dtype=float), + efc_D_in: wp.array2d(dtype=float), + efc_state_in: wp.array2d(dtype=int), + # In: + ctx_done_in: wp.array(dtype=bool), + # Out: + h_out: wp.array3d(dtype=float), +): + worldid, efcid = wp.tid() + + if ctx_done_in[worldid]: + return + + if efcid >= nefc_in[worldid]: + return + + efc_D = efc_D_in[worldid, efcid] + efc_state = efc_state_in[worldid, efcid] + + if state_check(efc_D, efc_state) == 0.0: + return + + rownnz = efc_J_rownnz_in[worldid, efcid] + rowadr = efc_J_rowadr_in[worldid, efcid] + + for i in range(rownnz): + sparseidi = rowadr + i + Ji = efc_J_in[worldid, 0, sparseidi] + colindi = efc_J_colind_in[worldid, 0, sparseidi] + for j in range(i, rownnz): + if j == i: + sparseidj = sparseidi + Jj = Ji + colindj = colindi + else: + sparseidj = rowadr + j + Jj = efc_J_in[worldid, 0, sparseidj] + colindj = efc_J_colind_in[worldid, 0, sparseidj] + + h = Ji * Jj * efc_D + wp.atomic_add(h_out[worldid, colindi], colindj, h) + + if i != j: + wp.atomic_add(h_out[worldid, colindj], colindi, h) + + def _update_gradient(m: types.Model, d: types.Data, ctx: SolverContext): # grad = Ma - qfrc_smooth - qfrc_constraint wp.launch(update_gradient_zero_grad_dot, dim=(d.nworld), inputs=[ctx.done], outputs=[ctx.grad_dot]) @@ -2912,126 +2891,15 @@ def _update_gradient(m: types.Model, d: types.Data, ctx: SolverContext): smooth.solve_m(m, d, ctx.Mgrad, ctx.grad) elif m.opt.solver == types.SolverType.NEWTON: # h = qM + (efc_J.T * efc_D * active) @ efc_J - if SPARSE_CONSTRAINT_JACOBIAN: - # TODO(team): improve performance for sparse code path - @wp.kernel(module="unique", enable_backward=False) - def _JTDAJ_sparse( - # Data in: - nefc_in: wp.array(dtype=int), - efc_J_rownnz_in: wp.array2d(dtype=int), - efc_J_rowadr_in: wp.array2d(dtype=int), - efc_J_colind_in: wp.array3d(dtype=int), - efc_J_in: wp.array3d(dtype=float), - efc_D_in: wp.array2d(dtype=float), - efc_state_in: wp.array2d(dtype=int), - # In: - ctx_done_in: wp.array(dtype=bool), - # Out: - h_out: wp.array3d(dtype=float), - ): - worldid, efcid = wp.tid() - - if ctx_done_in[worldid]: - return - - if efcid >= nefc_in[worldid]: - return - - efc_D = efc_D_in[worldid, efcid] - efc_state = efc_state_in[worldid, efcid] - - if state_check(efc_D, efc_state) == 0.0: - return - - rownnz = efc_J_rownnz_in[worldid, efcid] - rowadr = efc_J_rowadr_in[worldid, efcid] - - for i in range(rownnz): - sparseidi = rowadr + i - Ji = efc_J_in[worldid, 0, sparseidi] - colindi = efc_J_colind_in[worldid, 0, sparseidi] - for j in range(i, rownnz): - if j == i: - sparseidj = sparseidi - Jj = Ji - colindj = colindi - else: - sparseidj = rowadr + j - Jj = efc_J_in[worldid, 0, sparseidj] - colindj = efc_J_colind_in[worldid, 0, sparseidj] - - h = Ji * Jj * efc_D - wp.atomic_add(h_out[worldid, colindi], colindj, h) - - if i != j: - wp.atomic_add(h_out[worldid, colindj], colindi, h) - + if m.is_sparse: + ctx.h.zero_() wp.launch( - _JTDAJ_sparse, - dim=(d.nworld, d.njmax), - inputs=[ - d.nefc, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.D, - d.efc.state, - ctx.done, - ], - outputs=[ctx.h], + _JTDAJ_sparse, + dim=(d.nworld, d.njmax), + inputs=[d.nefc, d.efc.J_rownnz, d.efc.J_rowadr, d.efc.J_colind, d.efc.J, d.efc.D, d.efc.state, ctx.done], + outputs=[ctx.h], ) - if m.is_sparse: - wp.launch( - update_gradient_set_h_qM_lower_sparse, - dim=(d.nworld, m.qM_fullm_i.size), - inputs=[m.qM_fullm_i, m.qM_fullm_j, d.qM, ctx.done], - outputs=[ctx.h], - ) - else: - # dense M: copy qM directly into h - @wp.kernel(module="unique", enable_backward=False) - def _set_h_qM_dense( - nv: int, - qM_in: wp.array3d(dtype=float), - ctx_done_in: wp.array(dtype=bool), - ctx_h_out: wp.array3d(dtype=float), - ): - worldid, i, j = wp.tid() - if ctx_done_in[worldid]: - return - if i >= nv or j >= nv: - return - if i >= j: - ctx_h_out[worldid, i, j] += qM_in[worldid, i, j] - - wp.launch( - _set_h_qM_dense, - dim=(d.nworld, m.nv, m.nv), - inputs=[m.nv, d.qM, ctx.done], - outputs=[ctx.h], - ) - elif m.is_sparse: - num_blocks_ceil = ceil(m.nv / types.TILE_SIZE_JTDAJ_SPARSE) - lower_triangle_dim = int(num_blocks_ceil * (num_blocks_ceil + 1) / 2) - with scoped_mathdx_gemm_disabled(): - wp.launch_tiled( - update_gradient_JTDAJ_sparse_tiled( - types.TILE_SIZE_JTDAJ_SPARSE, d.njmax - ), - dim=(d.nworld, lower_triangle_dim), - inputs=[ - d.nefc, - d.efc.J, - d.efc.D, - d.efc.state, - ctx.done, - ], - outputs=[ctx.h], - block_dim=m.block_dim.update_gradient_JTDAJ_sparse, - ) - wp.launch( update_gradient_set_h_qM_lower_sparse, dim=(d.nworld, m.qM_fullm_i.size), @@ -3041,20 +2909,18 @@ def _set_h_qM_dense( else: with scoped_mathdx_gemm_disabled(): wp.launch_tiled( - update_gradient_JTDAJ_dense_tiled( - m.nv_pad, types.TILE_SIZE_JTDAJ_DENSE, d.njmax - ), - dim=d.nworld, - inputs=[ - d.nefc, - d.qM, - d.efc.J, - d.efc.D, - d.efc.state, - ctx.done, - ], - outputs=[ctx.h], - block_dim=m.block_dim.update_gradient_JTDAJ_dense, + update_gradient_JTDAJ_dense_tiled(m.nv_pad, types.TILE_SIZE_JTDAJ_DENSE, d.njmax), + dim=d.nworld, + inputs=[ + d.nefc, + d.qM, + d.efc.J, + d.efc.D, + d.efc.state, + ctx.done, + ], + outputs=[ctx.h], + block_dim=m.block_dim.update_gradient_JTDAJ_dense, ) if m.opt.cone == types.ConeType.ELLIPTIC: @@ -3081,60 +2947,60 @@ def _set_h_qM_dense( nblocks_perblock = int((d.naconmax + dim_block - 1) / dim_block) - if SPARSE_CONSTRAINT_JACOBIAN: + if m.is_sparse: wp.launch( - update_gradient_JTCJ_sparse, - dim=(d.naconmax, m.dof_tri_row.size), - inputs=[ - m.opt.impratio_invsqrt, - m.dof_tri_row, - m.dof_tri_col, - d.contact.dist, - d.contact.includemargin, - d.contact.friction, - d.contact.dim, - d.contact.efc_address, - d.contact.worldid, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.D, - d.efc.state, - d.naconmax, - d.nacon, - ctx.Jaref, - ctx.done, - nblocks_perblock, - dim_block, - ], - outputs=[ctx.h], + update_gradient_JTCJ_sparse, + dim=(dim_block, m.dof_tri_row.size), + inputs=[ + m.opt.impratio_invsqrt, + m.dof_tri_row, + m.dof_tri_col, + d.contact.dist, + d.contact.includemargin, + d.contact.friction, + d.contact.dim, + d.contact.efc_address, + d.contact.worldid, + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.D, + d.efc.state, + d.naconmax, + d.nacon, + ctx.Jaref, + ctx.done, + nblocks_perblock, + dim_block, + ], + outputs=[ctx.h], ) else: wp.launch( - update_gradient_JTCJ_dense, - dim=(dim_block, m.dof_tri_row.size), - inputs=[ - m.opt.impratio_invsqrt, - m.dof_tri_row, - m.dof_tri_col, - d.contact.dist, - d.contact.includemargin, - d.contact.friction, - d.contact.dim, - d.contact.efc_address, - d.contact.worldid, - d.efc.J, - d.efc.D, - d.efc.state, - d.naconmax, - d.nacon, - ctx.Jaref, - ctx.done, - nblocks_perblock, - dim_block, - ], - outputs=[ctx.h], + update_gradient_JTCJ_dense, + dim=(dim_block, m.dof_tri_row.size), + inputs=[ + m.opt.impratio_invsqrt, + m.dof_tri_row, + m.dof_tri_col, + d.contact.dist, + d.contact.includemargin, + d.contact.friction, + d.contact.dim, + d.contact.efc_address, + d.contact.worldid, + d.efc.J, + d.efc.D, + d.efc.state, + d.naconmax, + d.nacon, + ctx.Jaref, + ctx.done, + nblocks_perblock, + dim_block, + ], + outputs=[ctx.h], ) _cholesky_factorize_solve(m, d, ctx) @@ -3142,58 +3008,51 @@ def _set_h_qM_dense( raise ValueError(f"Unknown solver type: {m.opt.solver}") -def _update_gradient_incremental( - m: types.Model, d: types.Data, ctx: SolverContext -): +def _update_gradient_incremental(m: types.Model, d: types.Data, ctx: SolverContext): """Incremental gradient update: update H for changed constraints + re-factorize. Skips the full J^T*D*J rebuild by applying only the delta from constraints that changed QUADRATIC state, then re-factorizes and solves. """ - wp.launch( - update_gradient_zero_grad_dot, - dim=(d.nworld), - inputs=[ctx.done], - outputs=[ctx.grad_dot], - ) + wp.launch(update_gradient_zero_grad_dot, dim=(d.nworld), inputs=[ctx.done], outputs=[ctx.grad_dot]) wp.launch( - update_gradient_grad, - dim=(d.nworld, m.nv), - inputs=[d.qfrc_smooth, d.qfrc_constraint, d.efc.Ma, ctx.done], - outputs=[ctx.grad, ctx.grad_dot], + update_gradient_grad, + dim=(d.nworld, m.nv), + inputs=[d.qfrc_smooth, d.qfrc_constraint, d.efc.Ma, ctx.done], + outputs=[ctx.grad, ctx.grad_dot], ) # Update lower triangle of H with delta from changed constraints - if SPARSE_CONSTRAINT_JACOBIAN: + if m.is_sparse: wp.launch( - update_gradient_h_incremental_sparse, - dim=(d.nworld, ctx.changed_efc_ids.shape[1]), - inputs=[ - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.D, - d.efc.state, - ctx.changed_efc_ids, - ctx.changed_efc_count, - ], - outputs=[ctx.h], + update_gradient_h_incremental_sparse, + dim=(d.nworld, ctx.changed_efc_ids.shape[1]), + inputs=[ + d.efc.J_rownnz, + d.efc.J_rowadr, + d.efc.J_colind, + d.efc.J, + d.efc.D, + d.efc.state, + ctx.changed_efc_ids, + ctx.changed_efc_count, + ], + outputs=[ctx.h], ) else: lower_tri_dim = m.nv * (m.nv + 1) // 2 wp.launch( - update_gradient_h_incremental, - dim=(d.nworld, lower_tri_dim), - inputs=[ - d.efc.J, - d.efc.D, - d.efc.state, - ctx.changed_efc_ids, - ctx.changed_efc_count, - ], - outputs=[ctx.h], + update_gradient_h_incremental, + dim=(d.nworld, lower_tri_dim), + inputs=[ + d.efc.J, + d.efc.D, + d.efc.state, + ctx.changed_efc_ids, + ctx.changed_efc_count, + ], + outputs=[ctx.h], ) _cholesky_factorize_solve(m, d, ctx) @@ -3347,10 +3206,7 @@ def _solver_iteration( # path in update_constraint_efc has early returns that skip state change # tracking, and the additional JTCJ Hessian term depends on Jaref which # changes every iteration. - incremental = ( - m.opt.solver == types.SolverType.NEWTON - and m.opt.cone != types.ConeType.ELLIPTIC - ) + incremental = m.opt.solver == types.SolverType.NEWTON and m.opt.cone != types.ConeType.ELLIPTIC if incremental: # Must complete before update_constraint_efc which atomically increments. @@ -3422,18 +3278,10 @@ def init_context(m: types.Model, d: types.Data, ctx: SolverContext | InverseCont ctx.Jaref.zero_() wp.launch( - solve_init_jaref(m.is_sparse, m.nv, dofs_per_thread), - dim=(d.nworld, d.njmax, threads_per_efc), - inputs=[ - d.nefc, - d.qacc, - d.efc.J_rownnz, - d.efc.J_rowadr, - d.efc.J_colind, - d.efc.J, - d.efc.aref, - ], - outputs=[ctx.Jaref], + solve_init_jaref(m.is_sparse, m.nv, dofs_per_thread), + dim=(d.nworld, d.njmax, threads_per_efc), + inputs=[d.nefc, d.qacc, d.efc.J_rownnz, d.efc.J_rowadr, d.efc.J_colind, d.efc.J, d.efc.aref], + outputs=[ctx.Jaref], ) # Ma = qM @ qacc diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/support.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/support.py index b45472c459..995b2a7c84 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/support.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/support.py @@ -18,18 +18,52 @@ import warp as wp from mujoco.mjx.third_party.mujoco_warp._src.math import motion_cross +from mujoco.mjx.third_party.mujoco_warp._src.types import MJ_MINVAL from mujoco.mjx.third_party.mujoco_warp._src.types import ConeType from mujoco.mjx.third_party.mujoco_warp._src.types import Data +from mujoco.mjx.third_party.mujoco_warp._src.types import DynType from mujoco.mjx.third_party.mujoco_warp._src.types import JointType from mujoco.mjx.third_party.mujoco_warp._src.types import Model from mujoco.mjx.third_party.mujoco_warp._src.types import State from mujoco.mjx.third_party.mujoco_warp._src.types import vec5 +from mujoco.mjx.third_party.mujoco_warp._src.types import vec10f from mujoco.mjx.third_party.mujoco_warp._src.warp_util import cache_kernel from mujoco.mjx.third_party.mujoco_warp._src.warp_util import event_scope wp.set_module_options({"enable_backward": False}) +# TODO(team): kernel analyzer array slice? +@wp.func +def next_act( + # Model: + opt_timestep: float, # kernel_analyzer: ignore + actuator_dyntype: int, # kernel_analyzer: ignore + actuator_dynprm: vec10f, # kernel_analyzer: ignore + actuator_actrange: wp.vec2, # kernel_analyzer: ignore + # Data In: + act_in: float, # kernel_analyzer: ignore + act_dot_in: float, # kernel_analyzer: ignore + # In: + act_dot_scale: float, + clamp: bool, +) -> float: + # advance actuation + if actuator_dyntype == DynType.FILTEREXACT: + tau = wp.max(MJ_MINVAL, actuator_dynprm[0]) + act = act_in + act_dot_scale * act_dot_in * tau * (1.0 - wp.exp(-opt_timestep / tau)) + elif actuator_dyntype == DynType.USER: + return act_in + else: + act = act_in + act_dot_scale * act_dot_in * opt_timestep + + # clamp to actrange + if clamp: + act = wp.clamp(act, actuator_actrange[0], actuator_actrange[1]) + + return act + + @cache_kernel def mul_m_sparse(check_skip: bool): @wp.kernel(module="unique") diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/types.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/types.py index 94434ff9ef..11b7a7c090 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/types.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/types.py @@ -33,11 +33,8 @@ TILE_SIZE_JTDAJ_SPARSE = 16 TILE_SIZE_JTDAJ_DENSE = 16 -# TODO(team): remove after improving performance for sparse constraint jacobian -SPARSE_CONSTRAINT_JACOBIAN = False - -# TODO(team): remove after mjwarp depends on warp-lang >= 1.12 in pyproject.toml -TEXTURE_DTYPE = wp.Texture2D if hasattr(wp, "Texture2D") else int +# maximum number of plugin attributes +_NPLUGINATTR = 128 # TODO(team): add check that all wp.launch_tiled 'block_dim' settings are configurable @@ -53,7 +50,6 @@ class BlockDim: # forward euler_dense: int = 32 actuator_velocity: int = 32 - tendon_velocity: int = 32 # ray ray: int = 64 # sensor @@ -63,6 +59,7 @@ class BlockDim: cholesky_factorize: int = 32 cholesky_solve: int = 32 cholesky_factorize_solve: int = 32 + solve_LD_sparse_fused: int = 64 # solver update_gradient_cholesky: int = 64 update_gradient_cholesky_blocked: int = 32 @@ -351,6 +348,7 @@ class GeomType(enum.IntEnum): BOX: box MESH: mesh SDF: sdf + FLEX: flex """ PLANE = mujoco.mjtGeom.mjGEOM_PLANE @@ -362,6 +360,7 @@ class GeomType(enum.IntEnum): BOX = mujoco.mjtGeom.mjGEOM_BOX MESH = mujoco.mjtGeom.mjGEOM_MESH SDF = mujoco.mjtGeom.mjGEOM_SDF + FLEX = mujoco.mjtGeom.mjGEOM_FLEX # unsupported: NGEOMTYPES, ARROW*, LINE, SKIN, LABEL, NONE @@ -662,6 +661,10 @@ class vec11f(wp.types.vector(length=11, dtype=float)): pass +class vec_pluginattr(wp.types.vector(length=_NPLUGINATTR, dtype=float)): + pass + + class mat23f(wp.types.matrix(shape=(2, 3), dtype=float)): pass @@ -679,6 +682,7 @@ class mat63f(wp.types.matrix(shape=(6, 3), dtype=float)): vec8 = vec8f vec10 = vec10f vec11 = vec11f +vec128 = vec_pluginattr mat23 = mat23f mat43 = mat43f mat63 = mat63f @@ -841,6 +845,7 @@ class Model: nflexelem: number of elements in all flexes nflexelemdata: number of element vertex ids in all flexes nflexelemedge: number of element edge ids in all flexes + nflexshelldata: number of shell fragment vertex ids in all flexes nJfe: number of non-zeros in sparse flexedge Jacobian nmesh: number of meshes nmeshvert: number of vertices for all meshes @@ -857,6 +862,7 @@ class Model: nexclude: number of excluded geom pairs neq: number of equality constraints ntendon: number of tendons + nJten: number of non-zeros in sparse tendon Jacobian nwrap: number of wrap objects in all tendon paths nsensor: number of sensors nmocap: number of mocap bodies @@ -973,6 +979,11 @@ class Model: light_poscom0: global position rel. to sub-com in qpos0 (*, nlight, 3) light_pos0: global position rel. to body in qpos0 (*, nlight, 3) light_dir0: global direction in qpos0 (*, nlight, 3) + flex_contype: flex contact type (nflex,) + flex_conaffinity: flex contact affinity (nflex,) + flex_condim: contact dimensionality (1, 3, 4, 6) (nflex,) + flex_friction: friction for (slide, spin, roll) (nflex, 3) + flex_margin: detect contact if dist= 1.12 in pyproject.toml - textures: array("*", TEXTURE_DTYPE) - textures_registry: list[TEXTURE_DTYPE] + textures: array("*", wp.Texture2D) + textures_registry: list[wp.Texture2D] hfield_registry: dict hfield_bvh_id: array("nhfield", wp.uint64) hfield_bounds_size: array("nhfield", wp.vec3) - flex_mesh: wp.Mesh + flex_mesh_registry: dict flex_rgba: array("nflex", wp.vec4) - flex_bvh_id: wp.uint64 - flex_face_point: array("*", wp.vec3) - flex_faceadr: array("nflex", int) - flex_nface: int - flex_nwork: int - flex_group_root: array("nworld", int) - flex_elemdataadr: array("nflex", int) - flex_shell: array("*", int) - flex_shelldataadr: array("nflex", int) - flex_radius: array("nflex", float) - flex_workadr: array("nflex", int) - flex_worknum: array("nflex", int) + flex_bvh_id: array("*", wp.uint64) + flex_group_root: array("nworld", "*", int) flex_render_smooth: bool + bvh_nflexgeom: int + flex_dim_np: array("nflex", int) + flex_geom_flexid: array("*", int) + flex_geom_edgeid: array("*", int) bvh: wp.Bvh bvh_id: wp.uint64 lower: array("*", wp.vec3) @@ -1988,5 +1978,8 @@ class RenderContext: depth_adr: array("ncam", int) render_rgb: array("ncam", bool) render_depth: array("ncam", bool) + seg_data: array("*", int) + seg_adr: array("ncam", int) + render_seg: array("ncam", bool) znear: float total_rays: int diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/util_pkg.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/util_pkg.py index b7debea4d3..e2f8acf9cd 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/util_pkg.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/util_pkg.py @@ -37,9 +37,7 @@ def _parse_version(version_str: str) -> tuple[tuple[int, int | str], ...]: """ # Split on both '.' and '-' parts = re.split(r"[.\-]", version_str) - return tuple( - [(0, int(p)) if p.isdigit() else (-1, p) for p in parts] + [(0, 0)] - ) + return tuple([(0, int(p)) if p.isdigit() else (-1, p) for p in parts] + [(0, 0)]) def check_version(spec: str) -> bool: @@ -65,9 +63,7 @@ def check_version(spec: str) -> bool: """ match = re.match(r"^([a-zA-Z0-9_\-]+)(>=|<=|>|<|==|!=)(.+)$", spec) if not match: - raise ValueError( - f"Invalid version spec '{spec}'. Expected format: 'package>=version'" - ) + raise ValueError(f"Invalid version spec '{spec}'. Expected format: 'package>=version'") package_name, op, version_str = match.groups() required_version = _parse_version(version_str) @@ -87,11 +83,11 @@ def check_version(spec: str) -> bool: installed_version = _parse_version(installed_str) ops = { - ">=": operator.ge, - "<=": operator.le, - ">": operator.gt, - "<": operator.lt, - "==": operator.eq, - "!=": operator.ne, + ">=": operator.ge, + "<=": operator.le, + ">": operator.gt, + "<": operator.lt, + "==": operator.eq, + "!=": operator.ne, } return ops[op](installed_version, required_version) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/warp_util.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/warp_util.py index 4ac1cb0de6..960e426629 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/warp_util.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/warp_util.py @@ -146,13 +146,13 @@ def check_toolkit_driver(): if wp.get_device().is_cuda: if not wp.is_conditional_graph_supported(): warnings.warn( - """ + """ CUDA version < 12.4 detected - graph capture may be unreliable for < 12.3 - conditional graph nodes are not available for < 12.4 Model.opt.graph_conditional should be set to False """, - stacklevel=2, + stacklevel=2, ) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/pyproject.toml b/mjx/mujoco/mjx/third_party/mujoco_warp/pyproject.toml index 373052487b..e8d98eb460 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/pyproject.toml +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.10" dependencies = [ "absl-py", "etils[epath]", - "mujoco>=3.5.0", + "mujoco>=3.6.0", "numpy", "warp-lang>=1.12", ] @@ -55,7 +55,7 @@ dev = [ "ruff", "pygls>=1.0.0,<2.0.0", "lsprotocol>=2023.0.1,<2024.0.0", - "mujoco>=3.5.0.dev0", + "mujoco>=3.6.0.dev0", "warp-lang>=1.11.0.dev0", ] # TODO(team): cpu and cuda JAX optional dependencies are temporary, remove after we land MJX:Warp diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/viewer.py b/mjx/mujoco/mjx/third_party/mujoco_warp/viewer.py index 5f1de2d49f..cf65f02aec 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/viewer.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/viewer.py @@ -18,7 +18,7 @@ Usage: mjwarp-viewer [flags] Example: - mjwarp-viewer benchmark/humanoid/humanoid.xml -o "opt.solver=cg" + mjwarp-viewer benchmarks/humanoid/humanoid.xml -o "opt.solver=cg" """ import copy @@ -56,6 +56,7 @@ class EngineOptions(enum.IntEnum): _ENGINE = flags.DEFINE_enum_class("engine", EngineOptions.WARP, EngineOptions, "Simulation engine") _NCONMAX = flags.DEFINE_integer("nconmax", None, "Maximum number of contacts.") _NJMAX = flags.DEFINE_integer("njmax", None, "Maximum number of constraints per world.") +_NJMAX_NNZ = flags.DEFINE_integer("njmax_nnz", None, "Maximum number of non-zeros in constraint Jacobian.") _NCCDMAX = flags.DEFINE_integer("nccdmax", None, "Maximum number of CCD contacts per world.") _OVERRIDE = flags.DEFINE_multi_string("override", [], "Model overrides (notation: foo.bar = baz)", short_name="o") _KEYFRAME = flags.DEFINE_integer("keyframe", 0, "keyframe to initialize simulation.") @@ -149,7 +150,7 @@ def _main(argv: Sequence[str]) -> None: override_model(mjm, _OVERRIDE.value) m = mjw.put_model(mjm) override_model(m, _OVERRIDE.value) - d = mjw.put_data(mjm, mjd, nconmax=_NCONMAX.value, njmax=_NJMAX.value, nccdmax=_NCCDMAX.value) + d = mjw.put_data(mjm, mjd, nconmax=_NCONMAX.value, njmax=_NJMAX.value, njmax_nnz=_NJMAX_NNZ.value, nccdmax=_NCCDMAX.value) graph = _compile_step(m, d) if wp.get_device().is_cuda else None if graph is None: mjw.step(m, d) # warmup step diff --git a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/custom_call.py b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/custom_call.py index b46a107135..0adf643594 100644 --- a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/custom_call.py +++ b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/custom_call.py @@ -19,7 +19,7 @@ import warp as wp from warp._src.context import type_str from warp._src.jax import get_jax_device -from warp._src.types import array_t, launch_bounds_t, strides_from_shape +from warp._src.types import array_t, launch_bounds_t, matches_array_class, strides_from_shape from warp._src.utils import warn _wp_module_name_ = "warp.jax_experimental.custom_call" @@ -340,7 +340,7 @@ def warp_call_lowering(ctx, *args, kernel=None, launch_dims=None): wtype = warg.type rtt = ir.RankedTensorType(actual.type) - if not isinstance(wtype, wp.array): + if not matches_array_class(wtype, wp.array): raise Exception("Only contiguous arrays are supported for Jax kernel arguments") if not base_type_is_compatible(wtype.dtype, rtt.element_type): @@ -364,7 +364,7 @@ def warp_call_lowering(ctx, *args, kernel=None, launch_dims=None): for warg in wp_kernel.adj.args[len(args) :]: wtype = warg.type - if not isinstance(wtype, wp.array): + if not matches_array_class(wtype, wp.array): raise Exception("Only contiguous arrays are supported for Jax kernel arguments") # Infer dimensions from the first input. diff --git a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py index f5c925dd44..e9fe408f47 100644 --- a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py +++ b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py @@ -29,19 +29,23 @@ from warp._src.codegen import get_full_arg_spec, make_full_qualified_name from warp._src.context import CudaMemcpyKind from warp._src.jax import get_jax_device -from warp._src.types import array_t, launch_bounds_t, strides_from_shape, type_size_in_bytes, type_to_warp +from warp._src.types import ( + array_t, + launch_bounds_t, + matches_array_class, + strides_from_shape, + type_size_in_bytes, + type_to_warp, +) from .xla_ffi import * _wp_module_name_ = "warp.jax_experimental.ffi" -# Type alias for differentiable kernel cache key -DiffKernelCacheKey = tuple[Callable, tuple, int, str, tuple[str, ...]] - # Holders for the custom callbacks to keep them alive. -_FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {} -_FFI_DIFF_KERNEL_REGISTRY: dict[DiffKernelCacheKey, Callable] = {} -_FFI_CALLABLE_REGISTRY: dict[str, FfiCallable] = {} +_FFI_KERNEL_REGISTRY: dict[tuple, FfiKernel] = {} +_FFI_DIFF_KERNEL_REGISTRY: dict[tuple, Callable] = {} +_FFI_CALLABLE_REGISTRY: dict[tuple, FfiCallable] = {} _FFI_CALLBACK_REGISTRY: dict[str, ctypes.CFUNCTYPE] = {} _FFI_REGISTRY_LOCK = threading.Lock() @@ -61,6 +65,21 @@ def check_jax_version(): raise RuntimeError(msg) +def collapse_batch_dims(shape, desired_ndim): + # roll leading batch dims into one + while len(shape) > desired_ndim: + shape = (shape[0] * shape[1], *shape[2:]) + return shape + + +def compute_batch_size(shape, batch_ndim): + # compute product of batch dims at front + batch_size = 1 + for i in range(batch_ndim): + batch_size *= shape[i] + return batch_size + + class GraphMode(IntEnum): """CUDA graph capture modes for :func:`warp.jax_experimental.jax_callable`. @@ -91,7 +110,7 @@ def __init__(self, name, type, in_out=False): self.name = name self.type = type self.in_out = in_out - self.is_array = isinstance(type, wp.array) + self.is_array = matches_array_class(type, wp.array) if self.is_array: if hasattr(type.dtype, "_wp_scalar_type_"): @@ -125,7 +144,15 @@ def __init__(self, static_inputs, launch_dims): class FfiKernel: def __init__( - self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames, module_preload_mode + self, + kernel, + num_outputs, + vmap_method, + launch_dims, + output_dims, + in_out_argnames, + module_preload_mode, + has_side_effect=False, ): self.kernel = kernel self.name = generate_unique_name(kernel.func) @@ -134,6 +161,7 @@ def __init__( self.launch_dims = launch_dims self.output_dims = output_dims self.module_preload_mode = module_preload_mode + self.has_side_effect = has_side_effect self.first_array_arg = None self.launch_id = 0 self.launch_descriptors = {} @@ -250,7 +278,8 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): out_types.append(get_jax_output_type(input_arg, input_value.shape)) # launch dimensions - if launch_dims is None: + infer_launch_dims = launch_dims is None + if infer_launch_dims: # use the shape of the first input array if self.first_array_arg is not None: launch_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape) @@ -284,6 +313,7 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): out_types, vmap_method=vmap_method, input_output_aliases=self.input_output_aliases, + has_side_effect=self.has_side_effect, ) # preload on the specified devices @@ -303,7 +333,9 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): # save launch data to be retrieved by callback launch_id = self.launch_id - self.launch_descriptors[launch_id] = FfiLaunchDesc(static_inputs, launch_dims) + self.launch_descriptors[launch_id] = FfiLaunchDesc( + static_inputs, launch_dims if not infer_launch_dims else None + ) self.launch_id += 1 return call(*args, launch_id=launch_id) @@ -343,19 +375,23 @@ def ffi_callback(self, call_frame): assert num_inputs == self.num_inputs assert num_outputs == self.num_outputs - launch_bounds = launch_bounds_t(launch_desc.launch_dims) - # first kernel param is the launch bounds kernel_params = (ctypes.c_void_p * (1 + self.num_kernel_args))() - kernel_params[0] = ctypes.addressof(launch_bounds) - arg_refs = [] + batch_size = None # input and in-out args for i, input_arg in enumerate(self.input_args): if input_arg.is_array: buffer = inputs[i].contents - shape = buffer.dims[: input_arg.type.ndim] + shape = buffer.dims[: buffer.rank - input_arg.dtype_ndim] + if buffer.rank > input_arg.jax_ndim: + # handle batching + shape = collapse_batch_dims(shape, input_arg.type.ndim) + if batch_size is None: + batch_size = compute_batch_size( + buffer.dims[: buffer.rank], buffer.rank - input_arg.jax_ndim + ) strides = strides_from_shape(shape, input_arg.type.dtype) arg = array_t(buffer.data, 0, input_arg.type.ndim, shape, strides) kernel_params[i + 1] = ctypes.addressof(arg) @@ -370,12 +406,34 @@ def ffi_callback(self, call_frame): # pure output args (skip in-out FFI buffers) for i, output_arg in enumerate(self.output_args): buffer = outputs[i + self.num_in_out].contents - shape = buffer.dims[: output_arg.type.ndim] + shape = buffer.dims[: buffer.rank - output_arg.dtype_ndim] + if buffer.rank > output_arg.jax_ndim: + # handle batching + shape = collapse_batch_dims(shape, output_arg.type.ndim) + if batch_size is None: + batch_size = compute_batch_size( + buffer.dims[: buffer.rank], buffer.rank - output_arg.jax_ndim + ) strides = strides_from_shape(shape, output_arg.type.dtype) arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides) kernel_params[num_inputs + i + 1] = ctypes.addressof(arg) arg_refs.append(arg) # keep a reference + # determine launch bounds + if launch_desc.launch_dims is None: + # infer launch dims from argument shape, works with vmap + arr = arg_refs[self.first_array_arg] + launch_dims = arr.shape[: arr.ndim] + else: + # use specified launch dims + launch_dims = launch_desc.launch_dims + if batch_size is not None: + # roll batch size into the first launch dimension + launch_dims = (batch_size * launch_dims[0], *launch_dims[1:]) + + launch_bounds = launch_bounds_t(launch_dims) + kernel_params[0] = ctypes.addressof(launch_bounds) + # get device and stream device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) stream = get_stream_from_callframe(call_frame.contents) @@ -808,7 +866,7 @@ def ffi_callback(self, call_frame): for i, arg in enumerate(self.input_args): if arg.is_array: buffer = inputs[i].contents - shape = buffer.dims[: buffer.rank - arg.dtype_ndim] + shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) arg_list.append(arr) else: @@ -819,7 +877,7 @@ def ffi_callback(self, call_frame): # pure output args (skip in-out FFI buffers) for i, arg in enumerate(self.output_args): buffer = outputs[i + self.num_in_out].contents - shape = buffer.dims[: buffer.rank - arg.dtype_ndim] + shape = collapse_batch_dims(buffer.dims[: buffer.rank - arg.dtype_ndim], arg.type.ndim) arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) arg_list.append(arr) @@ -1095,6 +1153,7 @@ def jax_kernel( in_out_argnames=None, module_preload_mode=ModulePreloadMode.CURRENT_DEVICE, enable_backward: bool = False, + has_side_effect: bool = False, ): """Create a JAX callback from a Warp kernel. @@ -1103,21 +1162,23 @@ def jax_kernel( Args: kernel: The Warp kernel to launch. num_outputs: Specify the number of output arguments if greater than 1. - This must include the number of ``in_out_arguments``. + This must include the number of ``in_out_arguments``. vmap_method: String specifying how the callback transforms under ``vmap()``. - This argument can also be specified for individual calls. + This argument can also be specified for individual calls. launch_dims: Specify the default kernel launch dimensions. If None, launch - dimensions are inferred from the shape of the first array argument. - This argument can also be specified for individual calls. + dimensions are inferred from the shape of the first array argument. + This argument can also be specified for individual calls. output_dims: Specify the default dimensions of output arrays. If None, output - dimensions are inferred from the launch dimensions. - This argument can also be specified for individual calls. + dimensions are inferred from the launch dimensions. + This argument can also be specified for individual calls. in_out_argnames: Names of arguments that are both inputs and outputs (aliased buffers). These must be array arguments that appear before any pure output arguments in the kernel signature. The number of in-out arguments is included in ``num_outputs``. Not supported when ``enable_backward=True``. module_preload_mode: Specify the devices where the module should be preloaded. enable_backward: Enable automatic differentiation for this kernel. + has_side_effect: Whether the custom call has side effects. When True, + the FFI call will be executed even when the outputs are not used. Limitations: - All kernel arguments must be contiguous arrays or scalars. @@ -1129,21 +1190,41 @@ def jax_kernel( check_jax_version() + if isinstance(output_dims, dict): + hashable_output_dims = tuple(sorted(output_dims.items())) + elif hasattr(output_dims, "__len__"): + hashable_output_dims = tuple(output_dims) + else: + hashable_output_dims = output_dims + + if hasattr(launch_dims, "__len__"): + hashable_launch_dims = tuple(launch_dims) + else: + hashable_launch_dims = launch_dims + if not enable_backward: key = ( kernel.func, kernel.sig, num_outputs, vmap_method, - tuple(launch_dims) if launch_dims else launch_dims, - tuple(sorted(output_dims.items())) if output_dims else output_dims, + hashable_launch_dims, + hashable_output_dims, module_preload_mode, + has_side_effect, ) with _FFI_REGISTRY_LOCK: if key not in _FFI_KERNEL_REGISTRY: new_kernel = FfiKernel( - kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames, module_preload_mode + kernel, + num_outputs, + vmap_method, + launch_dims, + output_dims, + in_out_argnames, + module_preload_mode, + has_side_effect=has_side_effect, ) _FFI_KERNEL_REGISTRY[key] = new_kernel @@ -1173,7 +1254,7 @@ def jax_kernel( static_args = [] for i, p in enumerate(parameters[:num_inputs]): param_type = p.annotation - if not isinstance(param_type, wp.array): + if not matches_array_class(param_type, wp.array): if param_type in wp._src.types.value_types: static_args.append(i) else: @@ -1183,7 +1264,7 @@ def _resolve_launch_dims(call_args): # determine launch dimensions from the shape of the first input array for i, p in enumerate(parameters[:num_inputs]): param_type = p.annotation - if isinstance(param_type, wp.array): + if matches_array_class(param_type, wp.array): arg = call_args[i] arg_shape = tuple(arg.shape) if hasattr(param_type.dtype, "_wp_scalar_type_"): @@ -1203,7 +1284,13 @@ def fwd_kernel_wrapper(*args): fwd_kernel_wrapper.__annotations__ = {p.name: p.annotation for p in parameters} fwd_kernel_wrapper.__annotations__["return"] = None - jax_fwd_kernel = jax_callable(fwd_kernel_wrapper, num_outputs=num_outputs, vmap_method=vmap_method) + jax_fwd_kernel = jax_callable( + fwd_kernel_wrapper, + num_outputs=num_outputs, + vmap_method=vmap_method, + module_preload_mode=module_preload_mode, + has_side_effect=has_side_effect, + ) # backward arguments only include static args once bwd_arg_count = 2 * parameter_count - len(static_args) @@ -1285,6 +1372,8 @@ def bwd_kernel_wrapper(*args): bwd_kernel_wrapper, num_outputs=len(bwd_input_params) - len(static_args), vmap_method=vmap_method, + module_preload_mode=module_preload_mode, + has_side_effect=has_side_effect, ) differentiable_input_indices = [i for i in range(num_inputs) if i not in static_args] @@ -1331,7 +1420,7 @@ def bwd_function(*bwd_args): if ann is None: continue # Check if annotation is a warp array type (annotation is an instance of wp.array) - is_array_ann = isinstance(ann, wp.array) + is_array_ann = matches_array_class(ann, wp.array) if not is_array_ann: continue dtype_ndim = 0 @@ -1355,6 +1444,15 @@ def bwd_function(*bwd_args): jax_func = jax.custom_vjp(jax_fwd_kernel, nondiff_argnums=tuple(static_args)) jax_func.defvjp(fwd_function, bwd_function) + key = ( + kernel.func, + kernel.sig, + num_outputs, + vmap_method, + module_preload_mode, + has_side_effect, + ) + if static_args: static_names = [parameters[i].name for i in static_args] @@ -1364,7 +1462,7 @@ def _user_callable(*args): _user_callable.__signature__ = signature # Cache differentiable wrapper - key = (kernel.func, kernel.sig, num_outputs, vmap_method, tuple(sorted(static_names))) + key = (*key, tuple(sorted(static_names))) with _FFI_REGISTRY_LOCK: cached = _FFI_DIFF_KERNEL_REGISTRY.get(key) if cached is None: @@ -1373,7 +1471,7 @@ def _user_callable(*args): return _FFI_DIFF_KERNEL_REGISTRY[key] # Cache differentiable wrapper (no static args) - key = (kernel.func, kernel.sig, num_outputs, vmap_method, ()) + key = (*key, ()) with _FFI_REGISTRY_LOCK: cached = _FFI_DIFF_KERNEL_REGISTRY.get(key) if cached is None: @@ -1426,6 +1524,8 @@ def jax_callable( graph_cache_max: Maximum number of cached graphs captured using ``GraphMode.WARP``. If ``None``, use ``warp.jax_experimental.get_jax_callable_default_graph_cache_max()``. module_preload_mode: Specify the devices where the module should be preloaded. + has_side_effect: Whether the custom call has side effects. When True, + the FFI call will be executed even when the outputs are not used. Limitations: - All kernel arguments must be contiguous arrays or scalars. @@ -1440,14 +1540,22 @@ def jax_callable( if graph_cache_max is None: graph_cache_max = FfiCallable.default_graph_cache_max + if isinstance(output_dims, dict): + hashable_output_dims = tuple(sorted(output_dims.items())) + elif hasattr(output_dims, "__len__"): + hashable_output_dims = tuple(output_dims) + else: + hashable_output_dims = output_dims + # Note: we don't include graph_cache_max in the key, it is applied below. key = ( func, num_outputs, graph_mode, vmap_method, - tuple(sorted(output_dims.items())) if output_dims else output_dims, + hashable_output_dims, module_preload_mode, + has_side_effect, ) with _FFI_REGISTRY_LOCK: diff --git a/mjx/mujoco/mjx/warp/bvh.py b/mjx/mujoco/mjx/warp/bvh.py index 2bb522f2aa..2c1fee49fa 100644 --- a/mjx/mujoco/mjx/warp/bvh.py +++ b/mjx/mujoco/mjx/warp/bvh.py @@ -48,20 +48,27 @@ **{f.name: None for f in dataclasses.fields(mjwp_types.Callback) if f.init} ) + @ffi.format_args_for_warp def _refit_bvh_shim( # Model nworld: int, flex_dim: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), flex_elem: wp.array(dtype=int), + flex_elemadr: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elemnum: wp.array(dtype=int), + flex_radius: wp.array(dtype=float), + flex_shell: wp.array(dtype=int), + flex_shelldataadr: wp.array(dtype=int), flex_vertadr: wp.array(dtype=int), + flex_vertnum: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), geom_type: wp.array(dtype=int), nflex: int, - nflexelemdata: int, - nflexvert: int, + nflexelem: int, # Data flexvert_xpos: wp.array2d(dtype=wp.vec3), geom_xmat: wp.array2d(dtype=wp.mat33), @@ -77,15 +84,21 @@ def _refit_bvh_shim( _d.efc = _e _d.contact = _c _m.flex_dim = flex_dim + _m.flex_edge = flex_edge _m.flex_elem = flex_elem + _m.flex_elemadr = flex_elemadr + _m.flex_elemdataadr = flex_elemdataadr _m.flex_elemnum = flex_elemnum + _m.flex_radius = flex_radius + _m.flex_shell = flex_shell + _m.flex_shelldataadr = flex_shelldataadr _m.flex_vertadr = flex_vertadr + _m.flex_vertnum = flex_vertnum _m.geom_dataid = geom_dataid _m.geom_size = geom_size _m.geom_type = geom_type _m.nflex = nflex - _m.nflexelemdata = nflexelemdata - _m.nflexvert = nflexvert + _m.nflexelem = nflexelem _d.flexvert_xpos = flexvert_xpos _d.geom_xmat = geom_xmat _d.geom_xpos = geom_xpos @@ -113,15 +126,21 @@ def _refit_bvh_jax_impl( out = jf( d.qpos.shape[0], m._impl.flex_dim, + m._impl.flex_edge, m._impl.flex_elem, + m._impl.flex_elemadr, + m._impl.flex_elemdataadr, m._impl.flex_elemnum, + m._impl.flex_radius, + m._impl.flex_shell, + m._impl.flex_shelldataadr, m._impl.flex_vertadr, + m._impl.flex_vertnum, m.geom_dataid, m.geom_size, m.geom_type, m._impl.nflex, - m._impl.nflexelemdata, - m._impl.nflexvert, + m._impl.nflexelem, d._impl.flexvert_xpos, d.geom_xmat, d.geom_xpos, diff --git a/mjx/mujoco/mjx/warp/collision_driver.py b/mjx/mujoco/mjx/warp/collision_driver.py index 265bfba3ba..9d5412d9c9 100644 --- a/mjx/mujoco/mjx/warp/collision_driver.py +++ b/mjx/mujoco/mjx/warp/collision_driver.py @@ -52,8 +52,26 @@ def _collision_shim( # Model nworld: int, block_dim: mjwp_types.BlockDim, + flex_conaffinity: wp.array(dtype=int), + flex_condim: wp.array(dtype=int), + flex_contype: wp.array(dtype=int), + flex_dim: wp.array(dtype=int), + flex_elem: wp.array(dtype=int), + flex_elemadr: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), + flex_elemnum: wp.array(dtype=int), + flex_friction: wp.array(dtype=wp.vec3), + flex_margin: wp.array(dtype=float), + flex_radius: wp.array(dtype=float), + flex_shell: wp.array(dtype=int), + flex_shelldataadr: wp.array(dtype=int), + flex_shellnum: wp.array(dtype=int), + flex_vertadr: wp.array(dtype=int), + flex_vertflexid: wp.array(dtype=int), geom_aabb: wp.array3d(dtype=wp.vec3), + geom_conaffinity: wp.array(dtype=int), geom_condim: wp.array(dtype=int), + geom_contype: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_friction: wp.array2d(dtype=wp.vec3), geom_gap: wp.array2d(dtype=float), @@ -90,6 +108,10 @@ def _collision_shim( mesh_vert: wp.array(dtype=wp.vec3), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), + nflex: int, + nflexelem: int, + nflexshelldata: int, + nflexvert: int, ngeom: int, nmaxmeshdeg: int, nmaxpolygon: int, @@ -108,7 +130,7 @@ def _collision_shim( pair_solref: wp.array2d(dtype=wp.vec2), pair_solreffriction: wp.array2d(dtype=wp.vec2), plugin: wp.array(dtype=int), - plugin_attr: wp.array(dtype=wp.vec3f), + plugin_attr: wp.array(dtype=mjwp_types.vec_pluginattr), opt__broadphase: int, opt__broadphase_filter: int, opt__ccd_iterations: int, @@ -120,12 +142,15 @@ def _collision_shim( # Data naccdmax: int, naconmax: int, + flexvert_xpos: wp.array2d(dtype=wp.vec3), geom_xmat: wp.array2d(dtype=wp.mat33), geom_xpos: wp.array2d(dtype=wp.vec3), nacon: wp.array(dtype=int), ncollision: wp.array(dtype=int), contact__dim: wp.array(dtype=int), contact__dist: wp.array(dtype=float), + contact__efc_address: wp.array2d(dtype=int), + contact__flex: wp.array(dtype=wp.vec2i), contact__frame: wp.array(dtype=wp.mat33), contact__friction: wp.array(dtype=mjwp_types.vec5), contact__geom: wp.array(dtype=wp.vec2i), @@ -136,6 +161,7 @@ def _collision_shim( contact__solref: wp.array(dtype=wp.vec2), contact__solreffriction: wp.array(dtype=wp.vec2), contact__type: wp.array(dtype=int), + contact__vert: wp.array(dtype=wp.vec2i), contact__worldid: wp.array(dtype=int), ): _m.stat = _s @@ -144,8 +170,26 @@ def _collision_shim( _d.efc = _e _d.contact = _c _m.block_dim = block_dim + _m.flex_conaffinity = flex_conaffinity + _m.flex_condim = flex_condim + _m.flex_contype = flex_contype + _m.flex_dim = flex_dim + _m.flex_elem = flex_elem + _m.flex_elemadr = flex_elemadr + _m.flex_elemdataadr = flex_elemdataadr + _m.flex_elemnum = flex_elemnum + _m.flex_friction = flex_friction + _m.flex_margin = flex_margin + _m.flex_radius = flex_radius + _m.flex_shell = flex_shell + _m.flex_shelldataadr = flex_shelldataadr + _m.flex_shellnum = flex_shellnum + _m.flex_vertadr = flex_vertadr + _m.flex_vertflexid = flex_vertflexid _m.geom_aabb = geom_aabb + _m.geom_conaffinity = geom_conaffinity _m.geom_condim = geom_condim + _m.geom_contype = geom_contype _m.geom_dataid = geom_dataid _m.geom_friction = geom_friction _m.geom_gap = geom_gap @@ -182,6 +226,10 @@ def _collision_shim( _m.mesh_vert = mesh_vert _m.mesh_vertadr = mesh_vertadr _m.mesh_vertnum = mesh_vertnum + _m.nflex = nflex + _m.nflexelem = nflexelem + _m.nflexshelldata = nflexshelldata + _m.nflexvert = nflexvert _m.ngeom = ngeom _m.nmaxmeshdeg = nmaxmeshdeg _m.nmaxpolygon = nmaxpolygon @@ -211,6 +259,8 @@ def _collision_shim( _m.plugin_attr = plugin_attr _d.contact.dim = contact__dim _d.contact.dist = contact__dist + _d.contact.efc_address = contact__efc_address + _d.contact.flex = contact__flex _d.contact.frame = contact__frame _d.contact.friction = contact__friction _d.contact.geom = contact__geom @@ -221,7 +271,9 @@ def _collision_shim( _d.contact.solref = contact__solref _d.contact.solreffriction = contact__solreffriction _d.contact.type = contact__type + _d.contact.vert = contact__vert _d.contact.worldid = contact__worldid + _d.flexvert_xpos = flexvert_xpos _d.geom_xmat = geom_xmat _d.geom_xpos = geom_xpos _d.naccdmax = naccdmax @@ -238,6 +290,8 @@ def _collision_jax_impl(m: types.Model, d: types.Data): 'ncollision': d._impl.ncollision.shape, 'contact__dim': d._impl.contact__dim.shape, 'contact__dist': d._impl.contact__dist.shape, + 'contact__efc_address': d._impl.contact__efc_address.shape, + 'contact__flex': d._impl.contact__flex.shape, 'contact__frame': d._impl.contact__frame.shape, 'contact__friction': d._impl.contact__friction.shape, 'contact__geom': d._impl.contact__geom.shape, @@ -248,11 +302,12 @@ def _collision_jax_impl(m: types.Model, d: types.Data): 'contact__solref': d._impl.contact__solref.shape, 'contact__solreffriction': d._impl.contact__solreffriction.shape, 'contact__type': d._impl.contact__type.shape, + 'contact__vert': d._impl.contact__vert.shape, 'contact__worldid': d._impl.contact__worldid.shape, } jf = ffi.jax_callable_variadic_tuple( _collision_shim, - num_outputs=15, + num_outputs=18, output_dims=output_dims, vmap_method=None, in_out_argnames=set([ @@ -260,6 +315,8 @@ def _collision_jax_impl(m: types.Model, d: types.Data): 'ncollision', 'contact__dim', 'contact__dist', + 'contact__efc_address', + 'contact__flex', 'contact__frame', 'contact__friction', 'contact__geom', @@ -270,6 +327,7 @@ def _collision_jax_impl(m: types.Model, d: types.Data): 'contact__solref', 'contact__solreffriction', 'contact__type', + 'contact__vert', 'contact__worldid', ]), stage_in_argnames=set([ @@ -299,8 +357,26 @@ def _collision_jax_impl(m: types.Model, d: types.Data): out = jf( d.qpos.shape[0], m._impl.block_dim, + m._impl.flex_conaffinity, + m._impl.flex_condim, + m._impl.flex_contype, + m._impl.flex_dim, + m._impl.flex_elem, + m._impl.flex_elemadr, + m._impl.flex_elemdataadr, + m._impl.flex_elemnum, + m._impl.flex_friction, + m._impl.flex_margin, + m._impl.flex_radius, + m._impl.flex_shell, + m._impl.flex_shelldataadr, + m._impl.flex_shellnum, + m._impl.flex_vertadr, + m._impl.flex_vertflexid, m.geom_aabb, + m.geom_conaffinity, m.geom_condim, + m.geom_contype, m.geom_dataid, m.geom_friction, m.geom_gap, @@ -337,6 +413,10 @@ def _collision_jax_impl(m: types.Model, d: types.Data): m.mesh_vert, m.mesh_vertadr, m.mesh_vertnum, + m._impl.nflex, + m._impl.nflexelem, + m._impl.nflexshelldata, + m._impl.nflexvert, m.ngeom, m._impl.nmaxmeshdeg, m._impl.nmaxpolygon, @@ -366,12 +446,15 @@ def _collision_jax_impl(m: types.Model, d: types.Data): m.opt._impl.sdf_iterations, d._impl.naccdmax, d._impl.naconmax, + d._impl.flexvert_xpos, d.geom_xmat, d.geom_xpos, d._impl.nacon, d._impl.ncollision, d._impl.contact__dim, d._impl.contact__dist, + d._impl.contact__efc_address, + d._impl.contact__flex, d._impl.contact__frame, d._impl.contact__friction, d._impl.contact__geom, @@ -382,6 +465,7 @@ def _collision_jax_impl(m: types.Model, d: types.Data): d._impl.contact__solref, d._impl.contact__solreffriction, d._impl.contact__type, + d._impl.contact__vert, d._impl.contact__worldid, ) d = d.tree_replace({ @@ -389,17 +473,20 @@ def _collision_jax_impl(m: types.Model, d: types.Data): '_impl.ncollision': out[1], '_impl.contact__dim': out[2], '_impl.contact__dist': out[3], - '_impl.contact__frame': out[4], - '_impl.contact__friction': out[5], - '_impl.contact__geom': out[6], - '_impl.contact__geomcollisionid': out[7], - '_impl.contact__includemargin': out[8], - '_impl.contact__pos': out[9], - '_impl.contact__solimp': out[10], - '_impl.contact__solref': out[11], - '_impl.contact__solreffriction': out[12], - '_impl.contact__type': out[13], - '_impl.contact__worldid': out[14], + '_impl.contact__efc_address': out[4], + '_impl.contact__flex': out[5], + '_impl.contact__frame': out[6], + '_impl.contact__friction': out[7], + '_impl.contact__geom': out[8], + '_impl.contact__geomcollisionid': out[9], + '_impl.contact__includemargin': out[10], + '_impl.contact__pos': out[11], + '_impl.contact__solimp': out[12], + '_impl.contact__solref': out[13], + '_impl.contact__solreffriction': out[14], + '_impl.contact__type': out[15], + '_impl.contact__vert': out[16], + '_impl.contact__worldid': out[17], }) return d diff --git a/mjx/mujoco/mjx/warp/forward.py b/mjx/mujoco/mjx/warp/forward.py index e38336011e..e0f5c05323 100644 --- a/mjx/mujoco/mjx/warp/forward.py +++ b/mjx/mujoco/mjx/warp/forward.py @@ -138,6 +138,10 @@ def _forward_shim( eq_type: wp.array(dtype=int), eq_wld_adr: wp.array(dtype=int), flex_bending: wp.array2d(dtype=float), + flex_centered: wp.array(dtype=bool), + flex_conaffinity: wp.array(dtype=int), + flex_condim: wp.array(dtype=int), + flex_contype: wp.array(dtype=int), flex_damping: wp.array(dtype=float), flex_dim: wp.array(dtype=int), flex_edge: wp.array(dtype=wp.vec2i), @@ -146,12 +150,22 @@ def _forward_shim( flex_edgenum: wp.array(dtype=int), flex_elem: wp.array(dtype=int), flex_elemadr: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elemedge: wp.array(dtype=int), flex_elemedgeadr: wp.array(dtype=int), flex_elemnum: wp.array(dtype=int), + flex_friction: wp.array(dtype=wp.vec3), + flex_margin: wp.array(dtype=float), + flex_radius: wp.array(dtype=float), + flex_shell: wp.array(dtype=int), + flex_shelldataadr: wp.array(dtype=int), + flex_shellnum: wp.array(dtype=int), flex_stiffness: wp.array2d(dtype=float), + flex_vert: wp.array(dtype=wp.vec3), flex_vertadr: wp.array(dtype=int), flex_vertbodyid: wp.array(dtype=int), + flex_vertflexid: wp.array(dtype=int), + flex_vertnum: wp.array(dtype=int), flexedge_J_colind: wp.array(dtype=int), flexedge_J_rowadr: wp.array(dtype=int), flexedge_J_rownnz: wp.array(dtype=int), @@ -159,7 +173,9 @@ def _forward_shim( flexedge_length0: wp.array(dtype=float), geom_aabb: wp.array3d(dtype=wp.vec3), geom_bodyid: wp.array(dtype=int), + geom_conaffinity: wp.array(dtype=int), geom_condim: wp.array(dtype=int), + geom_contype: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_fluid: wp.array2d(dtype=float), geom_friction: wp.array2d(dtype=wp.vec3), @@ -213,6 +229,7 @@ def _forward_shim( light_targetbodyid: wp.array(dtype=int), mapM2M: wp.array(dtype=int), mat_rgba: wp.array2d(dtype=wp.vec4), + max_ten_J_rownnz: int, mesh_face: wp.array(dtype=wp.vec3i), mesh_faceadr: wp.array(dtype=int), mesh_graph: wp.array(dtype=int), @@ -235,6 +252,7 @@ def _forward_shim( mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), nC: int, + nJten: int, na: int, nacttrnbody: int, nbody: int, @@ -244,6 +262,7 @@ def _forward_shim( nflex: int, nflexedge: int, nflexelem: int, + nflexshelldata: int, nflexvert: int, ngeom: int, ngravcomp: int, @@ -279,7 +298,9 @@ def _forward_shim( pair_solref: wp.array2d(dtype=wp.vec2), pair_solreffriction: wp.array2d(dtype=wp.vec2), plugin: wp.array(dtype=int), - plugin_attr: wp.array(dtype=wp.vec3f), + plugin_attr: wp.array(dtype=mjwp_types.vec_pluginattr), + qLD_all_updates: wp.array(dtype=wp.vec3i), + qLD_level_offsets: wp.array(dtype=int), qLD_updates: tuple[wp.array(dtype=wp.vec3i), ...], qM_fullm_i: wp.array(dtype=int), qM_fullm_j: wp.array(dtype=int), @@ -323,6 +344,9 @@ def _forward_shim( site_type: wp.array(dtype=int), taxel_sensorid: wp.array(dtype=int), taxel_vertadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_rownnz: wp.array(dtype=int), tendon_actfrclimited: wp.array(dtype=bool), tendon_actfrcrange: wp.array2d(dtype=wp.vec2), tendon_adr: wp.array(dtype=int), @@ -382,6 +406,7 @@ def _forward_shim( naccdmax: int, naconmax: int, njmax: int, + njmax_nnz: int, act: wp.array2d(dtype=float), act_dot: wp.array2d(dtype=float), actuator_force: wp.array2d(dtype=float), @@ -446,7 +471,7 @@ def _forward_shim( subtree_angmom: wp.array2d(dtype=wp.vec3), subtree_com: wp.array2d(dtype=wp.vec3), subtree_linvel: wp.array2d(dtype=wp.vec3), - ten_J: wp.array3d(dtype=float), + ten_J: wp.array2d(dtype=float), ten_length: wp.array2d(dtype=float), ten_velocity: wp.array2d(dtype=float), ten_wrapadr: wp.array2d(dtype=int), @@ -466,6 +491,7 @@ def _forward_shim( contact__dim: wp.array(dtype=int), contact__dist: wp.array(dtype=float), contact__efc_address: wp.array2d(dtype=int), + contact__flex: wp.array(dtype=wp.vec2i), contact__frame: wp.array(dtype=wp.mat33), contact__friction: wp.array(dtype=mjwp_types.vec5), contact__geom: wp.array(dtype=wp.vec2i), @@ -476,6 +502,7 @@ def _forward_shim( contact__solref: wp.array(dtype=wp.vec2), contact__solreffriction: wp.array(dtype=wp.vec2), contact__type: wp.array(dtype=int), + contact__vert: wp.array(dtype=wp.vec2i), contact__worldid: wp.array(dtype=int), efc__D: wp.array2d(dtype=float), efc__J: wp.array3d(dtype=float), @@ -585,6 +612,10 @@ def _forward_shim( _m.eq_type = eq_type _m.eq_wld_adr = eq_wld_adr _m.flex_bending = flex_bending + _m.flex_centered = flex_centered + _m.flex_conaffinity = flex_conaffinity + _m.flex_condim = flex_condim + _m.flex_contype = flex_contype _m.flex_damping = flex_damping _m.flex_dim = flex_dim _m.flex_edge = flex_edge @@ -593,12 +624,22 @@ def _forward_shim( _m.flex_edgenum = flex_edgenum _m.flex_elem = flex_elem _m.flex_elemadr = flex_elemadr + _m.flex_elemdataadr = flex_elemdataadr _m.flex_elemedge = flex_elemedge _m.flex_elemedgeadr = flex_elemedgeadr _m.flex_elemnum = flex_elemnum + _m.flex_friction = flex_friction + _m.flex_margin = flex_margin + _m.flex_radius = flex_radius + _m.flex_shell = flex_shell + _m.flex_shelldataadr = flex_shelldataadr + _m.flex_shellnum = flex_shellnum _m.flex_stiffness = flex_stiffness + _m.flex_vert = flex_vert _m.flex_vertadr = flex_vertadr _m.flex_vertbodyid = flex_vertbodyid + _m.flex_vertflexid = flex_vertflexid + _m.flex_vertnum = flex_vertnum _m.flexedge_J_colind = flexedge_J_colind _m.flexedge_J_rowadr = flexedge_J_rowadr _m.flexedge_J_rownnz = flexedge_J_rownnz @@ -606,7 +647,9 @@ def _forward_shim( _m.flexedge_length0 = flexedge_length0 _m.geom_aabb = geom_aabb _m.geom_bodyid = geom_bodyid + _m.geom_conaffinity = geom_conaffinity _m.geom_condim = geom_condim + _m.geom_contype = geom_contype _m.geom_dataid = geom_dataid _m.geom_fluid = geom_fluid _m.geom_friction = geom_friction @@ -660,6 +703,7 @@ def _forward_shim( _m.light_targetbodyid = light_targetbodyid _m.mapM2M = mapM2M _m.mat_rgba = mat_rgba + _m.max_ten_J_rownnz = max_ten_J_rownnz _m.mesh_face = mesh_face _m.mesh_faceadr = mesh_faceadr _m.mesh_graph = mesh_graph @@ -682,6 +726,7 @@ def _forward_shim( _m.mesh_vertadr = mesh_vertadr _m.mesh_vertnum = mesh_vertnum _m.nC = nC + _m.nJten = nJten _m.na = na _m.nacttrnbody = nacttrnbody _m.nbody = nbody @@ -691,6 +736,7 @@ def _forward_shim( _m.nflex = nflex _m.nflexedge = nflexedge _m.nflexelem = nflexelem + _m.nflexshelldata = nflexshelldata _m.nflexvert = nflexvert _m.ngeom = ngeom _m.ngravcomp = ngravcomp @@ -753,6 +799,8 @@ def _forward_shim( _m.pair_solreffriction = pair_solreffriction _m.plugin = plugin _m.plugin_attr = plugin_attr + _m.qLD_all_updates = qLD_all_updates + _m.qLD_level_offsets = qLD_level_offsets _m.qLD_updates = qLD_updates _m.qM_fullm_i = qM_fullm_i _m.qM_fullm_j = qM_fullm_j @@ -797,6 +845,9 @@ def _forward_shim( _m.stat.meaninertia = stat__meaninertia _m.taxel_sensorid = taxel_sensorid _m.taxel_vertadr = taxel_vertadr + _m.ten_J_colind = ten_J_colind + _m.ten_J_rowadr = ten_J_rowadr + _m.ten_J_rownnz = ten_J_rownnz _m.tendon_actfrclimited = tendon_actfrclimited _m.tendon_actfrcrange = tendon_actfrcrange _m.tendon_adr = tendon_adr @@ -842,6 +893,7 @@ def _forward_shim( _d.contact.dim = contact__dim _d.contact.dist = contact__dist _d.contact.efc_address = contact__efc_address + _d.contact.flex = contact__flex _d.contact.frame = contact__frame _d.contact.friction = contact__friction _d.contact.geom = contact__geom @@ -852,6 +904,7 @@ def _forward_shim( _d.contact.solref = contact__solref _d.contact.solreffriction = contact__solreffriction _d.contact.type = contact__type + _d.contact.vert = contact__vert _d.contact.worldid = contact__worldid _d.crb = crb _d.ctrl = ctrl @@ -895,6 +948,7 @@ def _forward_shim( _d.nf = nf _d.nisland = nisland _d.njmax = njmax + _d.njmax_nnz = njmax_nnz _d.nl = nl _d.qLD = qLD _d.qLDiagInv = qLDiagInv @@ -1018,6 +1072,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): 'contact__dim': d._impl.contact__dim.shape, 'contact__dist': d._impl.contact__dist.shape, 'contact__efc_address': d._impl.contact__efc_address.shape, + 'contact__flex': d._impl.contact__flex.shape, 'contact__frame': d._impl.contact__frame.shape, 'contact__friction': d._impl.contact__friction.shape, 'contact__geom': d._impl.contact__geom.shape, @@ -1028,6 +1083,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): 'contact__solref': d._impl.contact__solref.shape, 'contact__solreffriction': d._impl.contact__solreffriction.shape, 'contact__type': d._impl.contact__type.shape, + 'contact__vert': d._impl.contact__vert.shape, 'contact__worldid': d._impl.contact__worldid.shape, 'efc__D': d._impl.efc__D.shape, 'efc__J': d._impl.efc__J.shape, @@ -1047,7 +1103,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): } jf = ffi.jax_callable_variadic_tuple( _forward_shim, - num_outputs=100, + num_outputs=102, output_dims=output_dims, vmap_method=None, in_out_argnames=set([ @@ -1125,6 +1181,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): 'contact__dim', 'contact__dist', 'contact__efc_address', + 'contact__flex', 'contact__frame', 'contact__friction', 'contact__geom', @@ -1135,6 +1192,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): 'contact__solref', 'contact__solreffriction', 'contact__type', + 'contact__vert', 'contact__worldid', 'efc__D', 'efc__J', @@ -1417,6 +1475,10 @@ def _forward_jax_impl(m: types.Model, d: types.Data): m.eq_type, m._impl.eq_wld_adr, m._impl.flex_bending, + m._impl.flex_centered, + m._impl.flex_conaffinity, + m._impl.flex_condim, + m._impl.flex_contype, m._impl.flex_damping, m._impl.flex_dim, m._impl.flex_edge, @@ -1425,12 +1487,22 @@ def _forward_jax_impl(m: types.Model, d: types.Data): m._impl.flex_edgenum, m._impl.flex_elem, m._impl.flex_elemadr, + m._impl.flex_elemdataadr, m._impl.flex_elemedge, m._impl.flex_elemedgeadr, m._impl.flex_elemnum, + m._impl.flex_friction, + m._impl.flex_margin, + m._impl.flex_radius, + m._impl.flex_shell, + m._impl.flex_shelldataadr, + m._impl.flex_shellnum, m._impl.flex_stiffness, + m._impl.flex_vert, m._impl.flex_vertadr, m._impl.flex_vertbodyid, + m._impl.flex_vertflexid, + m._impl.flex_vertnum, m._impl.flexedge_J_colind, m._impl.flexedge_J_rowadr, m._impl.flexedge_J_rownnz, @@ -1438,7 +1510,9 @@ def _forward_jax_impl(m: types.Model, d: types.Data): m._impl.flexedge_length0, m.geom_aabb, m.geom_bodyid, + m.geom_conaffinity, m.geom_condim, + m.geom_contype, m.geom_dataid, m.geom_fluid, m.geom_friction, @@ -1492,6 +1566,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): m._impl.light_targetbodyid, m._impl.mapM2M, m.mat_rgba, + m._impl.max_ten_J_rownnz, m.mesh_face, m.mesh_faceadr, m.mesh_graph, @@ -1514,6 +1589,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): m.mesh_vertadr, m.mesh_vertnum, m.nC, + m.nJten, m.na, m._impl.nacttrnbody, m.nbody, @@ -1523,6 +1599,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): m._impl.nflex, m._impl.nflexedge, m._impl.nflexelem, + m._impl.nflexshelldata, m._impl.nflexvert, m.ngeom, m.ngravcomp, @@ -1559,6 +1636,8 @@ def _forward_jax_impl(m: types.Model, d: types.Data): m.pair_solreffriction, m._impl.plugin, m._impl.plugin_attr, + m._impl.qLD_all_updates, + m._impl.qLD_level_offsets, m._impl.qLD_updates, m._impl.qM_fullm_i, m._impl.qM_fullm_j, @@ -1602,6 +1681,9 @@ def _forward_jax_impl(m: types.Model, d: types.Data): m.site_type, m._impl.taxel_sensorid, m._impl.taxel_vertadr, + m._impl.ten_J_colind, + m._impl.ten_J_rowadr, + m._impl.ten_J_rownnz, m.tendon_actfrclimited, m.tendon_actfrcrange, m.tendon_adr, @@ -1660,6 +1742,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): d._impl.naccdmax, d._impl.naconmax, d._impl.njmax, + d._impl.njmax_nnz, d.act, d.act_dot, d.actuator_force, @@ -1744,6 +1827,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): d._impl.contact__dim, d._impl.contact__dist, d._impl.contact__efc_address, + d._impl.contact__flex, d._impl.contact__frame, d._impl.contact__friction, d._impl.contact__geom, @@ -1754,6 +1838,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data): d._impl.contact__solref, d._impl.contact__solreffriction, d._impl.contact__type, + d._impl.contact__vert, d._impl.contact__worldid, d._impl.efc__D, d._impl.efc__J, @@ -1846,32 +1931,34 @@ def _forward_jax_impl(m: types.Model, d: types.Data): '_impl.contact__dim': out[71], '_impl.contact__dist': out[72], '_impl.contact__efc_address': out[73], - '_impl.contact__frame': out[74], - '_impl.contact__friction': out[75], - '_impl.contact__geom': out[76], - '_impl.contact__geomcollisionid': out[77], - '_impl.contact__includemargin': out[78], - '_impl.contact__pos': out[79], - '_impl.contact__solimp': out[80], - '_impl.contact__solref': out[81], - '_impl.contact__solreffriction': out[82], - '_impl.contact__type': out[83], - '_impl.contact__worldid': out[84], - '_impl.efc__D': out[85], - '_impl.efc__J': out[86], - '_impl.efc__J_colind': out[87], - '_impl.efc__J_rowadr': out[88], - '_impl.efc__J_rownnz': out[89], - '_impl.efc__Ma': out[90], - '_impl.efc__aref': out[91], - '_impl.efc__force': out[92], - '_impl.efc__frictionloss': out[93], - '_impl.efc__id': out[94], - '_impl.efc__margin': out[95], - '_impl.efc__pos': out[96], - '_impl.efc__state': out[97], - '_impl.efc__type': out[98], - '_impl.efc__vel': out[99], + '_impl.contact__flex': out[74], + '_impl.contact__frame': out[75], + '_impl.contact__friction': out[76], + '_impl.contact__geom': out[77], + '_impl.contact__geomcollisionid': out[78], + '_impl.contact__includemargin': out[79], + '_impl.contact__pos': out[80], + '_impl.contact__solimp': out[81], + '_impl.contact__solref': out[82], + '_impl.contact__solreffriction': out[83], + '_impl.contact__type': out[84], + '_impl.contact__vert': out[85], + '_impl.contact__worldid': out[86], + '_impl.efc__D': out[87], + '_impl.efc__J': out[88], + '_impl.efc__J_colind': out[89], + '_impl.efc__J_rowadr': out[90], + '_impl.efc__J_rownnz': out[91], + '_impl.efc__Ma': out[92], + '_impl.efc__aref': out[93], + '_impl.efc__force': out[94], + '_impl.efc__frictionloss': out[95], + '_impl.efc__id': out[96], + '_impl.efc__margin': out[97], + '_impl.efc__pos': out[98], + '_impl.efc__state': out[99], + '_impl.efc__type': out[100], + '_impl.efc__vel': out[101], }) return d @@ -1980,6 +2067,10 @@ def _step_shim( eq_type: wp.array(dtype=int), eq_wld_adr: wp.array(dtype=int), flex_bending: wp.array2d(dtype=float), + flex_centered: wp.array(dtype=bool), + flex_conaffinity: wp.array(dtype=int), + flex_condim: wp.array(dtype=int), + flex_contype: wp.array(dtype=int), flex_damping: wp.array(dtype=float), flex_dim: wp.array(dtype=int), flex_edge: wp.array(dtype=wp.vec2i), @@ -1988,12 +2079,22 @@ def _step_shim( flex_edgenum: wp.array(dtype=int), flex_elem: wp.array(dtype=int), flex_elemadr: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elemedge: wp.array(dtype=int), flex_elemedgeadr: wp.array(dtype=int), flex_elemnum: wp.array(dtype=int), + flex_friction: wp.array(dtype=wp.vec3), + flex_margin: wp.array(dtype=float), + flex_radius: wp.array(dtype=float), + flex_shell: wp.array(dtype=int), + flex_shelldataadr: wp.array(dtype=int), + flex_shellnum: wp.array(dtype=int), flex_stiffness: wp.array2d(dtype=float), + flex_vert: wp.array(dtype=wp.vec3), flex_vertadr: wp.array(dtype=int), flex_vertbodyid: wp.array(dtype=int), + flex_vertflexid: wp.array(dtype=int), + flex_vertnum: wp.array(dtype=int), flexedge_J_colind: wp.array(dtype=int), flexedge_J_rowadr: wp.array(dtype=int), flexedge_J_rownnz: wp.array(dtype=int), @@ -2001,7 +2102,9 @@ def _step_shim( flexedge_length0: wp.array(dtype=float), geom_aabb: wp.array3d(dtype=wp.vec3), geom_bodyid: wp.array(dtype=int), + geom_conaffinity: wp.array(dtype=int), geom_condim: wp.array(dtype=int), + geom_contype: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_fluid: wp.array2d(dtype=float), geom_friction: wp.array2d(dtype=wp.vec3), @@ -2055,6 +2158,7 @@ def _step_shim( light_targetbodyid: wp.array(dtype=int), mapM2M: wp.array(dtype=int), mat_rgba: wp.array2d(dtype=wp.vec4), + max_ten_J_rownnz: int, mesh_face: wp.array(dtype=wp.vec3i), mesh_faceadr: wp.array(dtype=int), mesh_graph: wp.array(dtype=int), @@ -2077,6 +2181,7 @@ def _step_shim( mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), nC: int, + nJten: int, nM: int, na: int, nacttrnbody: int, @@ -2087,6 +2192,7 @@ def _step_shim( nflex: int, nflexedge: int, nflexelem: int, + nflexshelldata: int, nflexvert: int, ngeom: int, ngravcomp: int, @@ -2122,7 +2228,9 @@ def _step_shim( pair_solref: wp.array2d(dtype=wp.vec2), pair_solreffriction: wp.array2d(dtype=wp.vec2), plugin: wp.array(dtype=int), - plugin_attr: wp.array(dtype=wp.vec3f), + plugin_attr: wp.array(dtype=mjwp_types.vec_pluginattr), + qLD_all_updates: wp.array(dtype=wp.vec3i), + qLD_level_offsets: wp.array(dtype=int), qLD_updates: tuple[wp.array(dtype=wp.vec3i), ...], qM_fullm_i: wp.array(dtype=int), qM_fullm_j: wp.array(dtype=int), @@ -2166,6 +2274,9 @@ def _step_shim( site_type: wp.array(dtype=int), taxel_sensorid: wp.array(dtype=int), taxel_vertadr: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_rownnz: wp.array(dtype=int), tendon_actfrclimited: wp.array(dtype=bool), tendon_actfrcrange: wp.array2d(dtype=wp.vec2), tendon_adr: wp.array(dtype=int), @@ -2226,6 +2337,7 @@ def _step_shim( naccdmax: int, naconmax: int, njmax: int, + njmax_nnz: int, act: wp.array2d(dtype=float), act_dot: wp.array2d(dtype=float), actuator_force: wp.array2d(dtype=float), @@ -2290,7 +2402,7 @@ def _step_shim( subtree_angmom: wp.array2d(dtype=wp.vec3), subtree_com: wp.array2d(dtype=wp.vec3), subtree_linvel: wp.array2d(dtype=wp.vec3), - ten_J: wp.array3d(dtype=float), + ten_J: wp.array2d(dtype=float), ten_length: wp.array2d(dtype=float), ten_velocity: wp.array2d(dtype=float), ten_wrapadr: wp.array2d(dtype=int), @@ -2310,6 +2422,7 @@ def _step_shim( contact__dim: wp.array(dtype=int), contact__dist: wp.array(dtype=float), contact__efc_address: wp.array2d(dtype=int), + contact__flex: wp.array(dtype=wp.vec2i), contact__frame: wp.array(dtype=wp.mat33), contact__friction: wp.array(dtype=mjwp_types.vec5), contact__geom: wp.array(dtype=wp.vec2i), @@ -2320,6 +2433,7 @@ def _step_shim( contact__solref: wp.array(dtype=wp.vec2), contact__solreffriction: wp.array(dtype=wp.vec2), contact__type: wp.array(dtype=int), + contact__vert: wp.array(dtype=wp.vec2i), contact__worldid: wp.array(dtype=int), efc__D: wp.array2d(dtype=float), efc__J: wp.array3d(dtype=float), @@ -2429,6 +2543,10 @@ def _step_shim( _m.eq_type = eq_type _m.eq_wld_adr = eq_wld_adr _m.flex_bending = flex_bending + _m.flex_centered = flex_centered + _m.flex_conaffinity = flex_conaffinity + _m.flex_condim = flex_condim + _m.flex_contype = flex_contype _m.flex_damping = flex_damping _m.flex_dim = flex_dim _m.flex_edge = flex_edge @@ -2437,12 +2555,22 @@ def _step_shim( _m.flex_edgenum = flex_edgenum _m.flex_elem = flex_elem _m.flex_elemadr = flex_elemadr + _m.flex_elemdataadr = flex_elemdataadr _m.flex_elemedge = flex_elemedge _m.flex_elemedgeadr = flex_elemedgeadr _m.flex_elemnum = flex_elemnum + _m.flex_friction = flex_friction + _m.flex_margin = flex_margin + _m.flex_radius = flex_radius + _m.flex_shell = flex_shell + _m.flex_shelldataadr = flex_shelldataadr + _m.flex_shellnum = flex_shellnum _m.flex_stiffness = flex_stiffness + _m.flex_vert = flex_vert _m.flex_vertadr = flex_vertadr _m.flex_vertbodyid = flex_vertbodyid + _m.flex_vertflexid = flex_vertflexid + _m.flex_vertnum = flex_vertnum _m.flexedge_J_colind = flexedge_J_colind _m.flexedge_J_rowadr = flexedge_J_rowadr _m.flexedge_J_rownnz = flexedge_J_rownnz @@ -2450,7 +2578,9 @@ def _step_shim( _m.flexedge_length0 = flexedge_length0 _m.geom_aabb = geom_aabb _m.geom_bodyid = geom_bodyid + _m.geom_conaffinity = geom_conaffinity _m.geom_condim = geom_condim + _m.geom_contype = geom_contype _m.geom_dataid = geom_dataid _m.geom_fluid = geom_fluid _m.geom_friction = geom_friction @@ -2504,6 +2634,7 @@ def _step_shim( _m.light_targetbodyid = light_targetbodyid _m.mapM2M = mapM2M _m.mat_rgba = mat_rgba + _m.max_ten_J_rownnz = max_ten_J_rownnz _m.mesh_face = mesh_face _m.mesh_faceadr = mesh_faceadr _m.mesh_graph = mesh_graph @@ -2526,6 +2657,7 @@ def _step_shim( _m.mesh_vertadr = mesh_vertadr _m.mesh_vertnum = mesh_vertnum _m.nC = nC + _m.nJten = nJten _m.nM = nM _m.na = na _m.nacttrnbody = nacttrnbody @@ -2536,6 +2668,7 @@ def _step_shim( _m.nflex = nflex _m.nflexedge = nflexedge _m.nflexelem = nflexelem + _m.nflexshelldata = nflexshelldata _m.nflexvert = nflexvert _m.ngeom = ngeom _m.ngravcomp = ngravcomp @@ -2599,6 +2732,8 @@ def _step_shim( _m.pair_solreffriction = pair_solreffriction _m.plugin = plugin _m.plugin_attr = plugin_attr + _m.qLD_all_updates = qLD_all_updates + _m.qLD_level_offsets = qLD_level_offsets _m.qLD_updates = qLD_updates _m.qM_fullm_i = qM_fullm_i _m.qM_fullm_j = qM_fullm_j @@ -2643,6 +2778,9 @@ def _step_shim( _m.stat.meaninertia = stat__meaninertia _m.taxel_sensorid = taxel_sensorid _m.taxel_vertadr = taxel_vertadr + _m.ten_J_colind = ten_J_colind + _m.ten_J_rowadr = ten_J_rowadr + _m.ten_J_rownnz = ten_J_rownnz _m.tendon_actfrclimited = tendon_actfrclimited _m.tendon_actfrcrange = tendon_actfrcrange _m.tendon_adr = tendon_adr @@ -2688,6 +2826,7 @@ def _step_shim( _d.contact.dim = contact__dim _d.contact.dist = contact__dist _d.contact.efc_address = contact__efc_address + _d.contact.flex = contact__flex _d.contact.frame = contact__frame _d.contact.friction = contact__friction _d.contact.geom = contact__geom @@ -2698,6 +2837,7 @@ def _step_shim( _d.contact.solref = contact__solref _d.contact.solreffriction = contact__solreffriction _d.contact.type = contact__type + _d.contact.vert = contact__vert _d.contact.worldid = contact__worldid _d.crb = crb _d.ctrl = ctrl @@ -2741,6 +2881,7 @@ def _step_shim( _d.nf = nf _d.nisland = nisland _d.njmax = njmax + _d.njmax_nnz = njmax_nnz _d.nl = nl _d.qLD = qLD _d.qLDiagInv = qLDiagInv @@ -2868,6 +3009,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): 'contact__dim': d._impl.contact__dim.shape, 'contact__dist': d._impl.contact__dist.shape, 'contact__efc_address': d._impl.contact__efc_address.shape, + 'contact__flex': d._impl.contact__flex.shape, 'contact__frame': d._impl.contact__frame.shape, 'contact__friction': d._impl.contact__friction.shape, 'contact__geom': d._impl.contact__geom.shape, @@ -2878,6 +3020,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): 'contact__solref': d._impl.contact__solref.shape, 'contact__solreffriction': d._impl.contact__solreffriction.shape, 'contact__type': d._impl.contact__type.shape, + 'contact__vert': d._impl.contact__vert.shape, 'contact__worldid': d._impl.contact__worldid.shape, 'efc__D': d._impl.efc__D.shape, 'efc__J': d._impl.efc__J.shape, @@ -2897,7 +3040,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): } jf = ffi.jax_callable_variadic_tuple( _step_shim, - num_outputs=104, + num_outputs=106, output_dims=output_dims, vmap_method=None, in_out_argnames=set([ @@ -2979,6 +3122,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): 'contact__dim', 'contact__dist', 'contact__efc_address', + 'contact__flex', 'contact__frame', 'contact__friction', 'contact__geom', @@ -2989,6 +3133,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): 'contact__solref', 'contact__solreffriction', 'contact__type', + 'contact__vert', 'contact__worldid', 'efc__D', 'efc__J', @@ -3275,6 +3420,10 @@ def _step_jax_impl(m: types.Model, d: types.Data): m.eq_type, m._impl.eq_wld_adr, m._impl.flex_bending, + m._impl.flex_centered, + m._impl.flex_conaffinity, + m._impl.flex_condim, + m._impl.flex_contype, m._impl.flex_damping, m._impl.flex_dim, m._impl.flex_edge, @@ -3283,12 +3432,22 @@ def _step_jax_impl(m: types.Model, d: types.Data): m._impl.flex_edgenum, m._impl.flex_elem, m._impl.flex_elemadr, + m._impl.flex_elemdataadr, m._impl.flex_elemedge, m._impl.flex_elemedgeadr, m._impl.flex_elemnum, + m._impl.flex_friction, + m._impl.flex_margin, + m._impl.flex_radius, + m._impl.flex_shell, + m._impl.flex_shelldataadr, + m._impl.flex_shellnum, m._impl.flex_stiffness, + m._impl.flex_vert, m._impl.flex_vertadr, m._impl.flex_vertbodyid, + m._impl.flex_vertflexid, + m._impl.flex_vertnum, m._impl.flexedge_J_colind, m._impl.flexedge_J_rowadr, m._impl.flexedge_J_rownnz, @@ -3296,7 +3455,9 @@ def _step_jax_impl(m: types.Model, d: types.Data): m._impl.flexedge_length0, m.geom_aabb, m.geom_bodyid, + m.geom_conaffinity, m.geom_condim, + m.geom_contype, m.geom_dataid, m.geom_fluid, m.geom_friction, @@ -3350,6 +3511,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): m._impl.light_targetbodyid, m._impl.mapM2M, m.mat_rgba, + m._impl.max_ten_J_rownnz, m.mesh_face, m.mesh_faceadr, m.mesh_graph, @@ -3372,6 +3534,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): m.mesh_vertadr, m.mesh_vertnum, m.nC, + m.nJten, m.nM, m.na, m._impl.nacttrnbody, @@ -3382,6 +3545,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): m._impl.nflex, m._impl.nflexedge, m._impl.nflexelem, + m._impl.nflexshelldata, m._impl.nflexvert, m.ngeom, m.ngravcomp, @@ -3418,6 +3582,8 @@ def _step_jax_impl(m: types.Model, d: types.Data): m.pair_solreffriction, m._impl.plugin, m._impl.plugin_attr, + m._impl.qLD_all_updates, + m._impl.qLD_level_offsets, m._impl.qLD_updates, m._impl.qM_fullm_i, m._impl.qM_fullm_j, @@ -3461,6 +3627,9 @@ def _step_jax_impl(m: types.Model, d: types.Data): m.site_type, m._impl.taxel_sensorid, m._impl.taxel_vertadr, + m._impl.ten_J_colind, + m._impl.ten_J_rowadr, + m._impl.ten_J_rownnz, m.tendon_actfrclimited, m.tendon_actfrcrange, m.tendon_adr, @@ -3520,6 +3689,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): d._impl.naccdmax, d._impl.naconmax, d._impl.njmax, + d._impl.njmax_nnz, d.act, d.act_dot, d.actuator_force, @@ -3604,6 +3774,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): d._impl.contact__dim, d._impl.contact__dist, d._impl.contact__efc_address, + d._impl.contact__flex, d._impl.contact__frame, d._impl.contact__friction, d._impl.contact__geom, @@ -3614,6 +3785,7 @@ def _step_jax_impl(m: types.Model, d: types.Data): d._impl.contact__solref, d._impl.contact__solreffriction, d._impl.contact__type, + d._impl.contact__vert, d._impl.contact__worldid, d._impl.efc__D, d._impl.efc__J, @@ -3710,32 +3882,34 @@ def _step_jax_impl(m: types.Model, d: types.Data): '_impl.contact__dim': out[75], '_impl.contact__dist': out[76], '_impl.contact__efc_address': out[77], - '_impl.contact__frame': out[78], - '_impl.contact__friction': out[79], - '_impl.contact__geom': out[80], - '_impl.contact__geomcollisionid': out[81], - '_impl.contact__includemargin': out[82], - '_impl.contact__pos': out[83], - '_impl.contact__solimp': out[84], - '_impl.contact__solref': out[85], - '_impl.contact__solreffriction': out[86], - '_impl.contact__type': out[87], - '_impl.contact__worldid': out[88], - '_impl.efc__D': out[89], - '_impl.efc__J': out[90], - '_impl.efc__J_colind': out[91], - '_impl.efc__J_rowadr': out[92], - '_impl.efc__J_rownnz': out[93], - '_impl.efc__Ma': out[94], - '_impl.efc__aref': out[95], - '_impl.efc__force': out[96], - '_impl.efc__frictionloss': out[97], - '_impl.efc__id': out[98], - '_impl.efc__margin': out[99], - '_impl.efc__pos': out[100], - '_impl.efc__state': out[101], - '_impl.efc__type': out[102], - '_impl.efc__vel': out[103], + '_impl.contact__flex': out[78], + '_impl.contact__frame': out[79], + '_impl.contact__friction': out[80], + '_impl.contact__geom': out[81], + '_impl.contact__geomcollisionid': out[82], + '_impl.contact__includemargin': out[83], + '_impl.contact__pos': out[84], + '_impl.contact__solimp': out[85], + '_impl.contact__solref': out[86], + '_impl.contact__solreffriction': out[87], + '_impl.contact__type': out[88], + '_impl.contact__vert': out[89], + '_impl.contact__worldid': out[90], + '_impl.efc__D': out[91], + '_impl.efc__J': out[92], + '_impl.efc__J_colind': out[93], + '_impl.efc__J_rowadr': out[94], + '_impl.efc__J_rownnz': out[95], + '_impl.efc__Ma': out[96], + '_impl.efc__aref': out[97], + '_impl.efc__force': out[98], + '_impl.efc__frictionloss': out[99], + '_impl.efc__id': out[100], + '_impl.efc__margin': out[101], + '_impl.efc__pos': out[102], + '_impl.efc__state': out[103], + '_impl.efc__type': out[104], + '_impl.efc__vel': out[105], }) return d diff --git a/mjx/mujoco/mjx/warp/forward_test.py b/mjx/mujoco/mjx/warp/forward_test.py index 15e874e3eb..b80c921dcc 100644 --- a/mjx/mujoco/mjx/warp/forward_test.py +++ b/mjx/mujoco/mjx/warp/forward_test.py @@ -157,7 +157,16 @@ def test_forward(self, xml: str, batch_size: int): m.ten_J_rowadr, m.ten_J_colind, ) - tu.assert_eq(dx._impl.ten_J, ten_J, 'ten_J') + # convert sparse warp ten_J to dense representation + warp_ten_J = np.zeros((m.ntendon, m.nv)) + mujoco.mju_sparse2dense( + warp_ten_J, + np.asarray(dx._impl.ten_J), + mx._impl.ten_J_rownnz, + mx._impl.ten_J_rowadr, + mx._impl.ten_J_colind, + ) + tu.assert_eq(warp_ten_J, ten_J, 'ten_J') tu.assert_attr_eq(dx._impl, d, 'ten_wrapadr') tu.assert_attr_eq(dx._impl, d, 'ten_wrapnum') tu.assert_attr_eq(dx._impl, d, 'wrap_xpos') diff --git a/mjx/mujoco/mjx/warp/render.py b/mjx/mujoco/mjx/warp/render.py index 0bab51a3fb..e723e6b076 100644 --- a/mjx/mujoco/mjx/warp/render.py +++ b/mjx/mujoco/mjx/warp/render.py @@ -48,6 +48,7 @@ **{f.name: None for f in dataclasses.fields(mjwp_types.Callback) if f.init} ) + @ffi.format_args_for_warp def _render_shim( # Model @@ -56,6 +57,9 @@ def _render_shim( cam_intrinsic: wp.array2d(dtype=wp.vec4), cam_projection: wp.array(dtype=int), cam_sensorsize: wp.array(dtype=wp.vec2), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), + flex_vertadr: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_matid: wp.array2d(dtype=int), geom_rgba: wp.array2d(dtype=wp.vec4), @@ -68,11 +72,11 @@ def _render_shim( mat_texid: wp.array3d(dtype=int), mat_texrepeat: wp.array2d(dtype=wp.vec2), mesh_faceadr: wp.array(dtype=int), - nflex: int, nlight: int, # Data cam_xmat: wp.array2d(dtype=wp.mat33), cam_xpos: wp.array2d(dtype=wp.vec3), + flexvert_xpos: wp.array2d(dtype=wp.vec3), geom_xmat: wp.array2d(dtype=wp.mat33), geom_xpos: wp.array2d(dtype=wp.vec3), light_xdir: wp.array2d(dtype=wp.vec3), @@ -91,6 +95,9 @@ def _render_shim( _m.cam_intrinsic = cam_intrinsic _m.cam_projection = cam_projection _m.cam_sensorsize = cam_sensorsize + _m.flex_edge = flex_edge + _m.flex_radius = flex_radius + _m.flex_vertadr = flex_vertadr _m.geom_dataid = geom_dataid _m.geom_matid = geom_matid _m.geom_rgba = geom_rgba @@ -103,10 +110,10 @@ def _render_shim( _m.mat_texid = mat_texid _m.mat_texrepeat = mat_texrepeat _m.mesh_faceadr = mesh_faceadr - _m.nflex = nflex _m.nlight = nlight _d.cam_xmat = cam_xmat _d.cam_xpos = cam_xpos + _d.flexvert_xpos = flexvert_xpos _d.geom_xmat = geom_xmat _d.geom_xpos = geom_xpos _d.light_xdir = light_xdir @@ -155,6 +162,9 @@ def _render_jax_impl(m: types.Model, d: types.Data, ctx: RenderContextPytree): m.cam_intrinsic, m._impl.cam_projection, m.cam_sensorsize, + m._impl.flex_edge, + m._impl.flex_radius, + m._impl.flex_vertadr, m.geom_dataid, m.geom_matid, m.geom_rgba, @@ -167,10 +177,10 @@ def _render_jax_impl(m: types.Model, d: types.Data, ctx: RenderContextPytree): m.mat_texid, m._impl.mat_texrepeat, m.mesh_faceadr, - m._impl.nflex, m.nlight, d.cam_xmat, d.cam_xpos, + d._impl.flexvert_xpos, d.geom_xmat, d.geom_xpos, d._impl.light_xdir, diff --git a/mjx/mujoco/mjx/warp/smooth.py b/mjx/mujoco/mjx/warp/smooth.py index eeef2d924e..3fe3a712b8 100644 --- a/mjx/mujoco/mjx/warp/smooth.py +++ b/mjx/mujoco/mjx/warp/smooth.py @@ -297,17 +297,20 @@ def kinematics_vmap( def _tendon_shim( # Model nworld: int, + body_dofadr: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), jnt_dofadr: wp.array(dtype=int), jnt_qposadr: wp.array(dtype=int), ntendon: int, - nv: int, nwrap: int, site_bodyid: wp.array(dtype=int), + ten_J_colind: wp.array(dtype=int), + ten_J_rowadr: wp.array(dtype=int), + ten_J_rownnz: wp.array(dtype=int), tendon_adr: wp.array(dtype=int), tendon_geom_adr: wp.array(dtype=int), tendon_jnt_adr: wp.array(dtype=int), @@ -327,7 +330,7 @@ def _tendon_shim( qpos: wp.array2d(dtype=float), site_xpos: wp.array2d(dtype=wp.vec3), subtree_com: wp.array2d(dtype=wp.vec3), - ten_J: wp.array3d(dtype=float), + ten_J: wp.array2d(dtype=float), ten_length: wp.array2d(dtype=float), ten_wrapadr: wp.array2d(dtype=int), ten_wrapnum: wp.array2d(dtype=int), @@ -339,17 +342,20 @@ def _tendon_shim( _m.callback = _cb _d.efc = _e _d.contact = _c + _m.body_dofadr = body_dofadr + _m.body_dofnum = body_dofnum _m.body_parentid = body_parentid _m.body_rootid = body_rootid - _m.dof_bodyid = dof_bodyid _m.geom_bodyid = geom_bodyid _m.geom_size = geom_size _m.jnt_dofadr = jnt_dofadr _m.jnt_qposadr = jnt_qposadr _m.ntendon = ntendon - _m.nv = nv _m.nwrap = nwrap _m.site_bodyid = site_bodyid + _m.ten_J_colind = ten_J_colind + _m.ten_J_rowadr = ten_J_rowadr + _m.ten_J_rownnz = ten_J_rownnz _m.tendon_adr = tendon_adr _m.tendon_geom_adr = tendon_geom_adr _m.tendon_jnt_adr = tendon_jnt_adr @@ -416,17 +422,20 @@ def _tendon_jax_impl(m: types.Model, d: types.Data): ) out = jf( d.qpos.shape[0], + m.body_dofadr, + m.body_dofnum, m.body_parentid, m.body_rootid, - m.dof_bodyid, m.geom_bodyid, m.geom_size, m.jnt_dofadr, m.jnt_qposadr, m.ntendon, - m.nv, m.nwrap, m.site_bodyid, + m._impl.ten_J_colind, + m._impl.ten_J_rowadr, + m._impl.ten_J_rownnz, m.tendon_adr, m._impl.tendon_geom_adr, m._impl.tendon_jnt_adr, diff --git a/mjx/mujoco/mjx/warp/types.py b/mjx/mujoco/mjx/warp/types.py index ff85189f6e..ef75edc40a 100644 --- a/mjx/mujoco/mjx/warp/types.py +++ b/mjx/mujoco/mjx/warp/types.py @@ -83,7 +83,7 @@ class BlockDim: qderiv_actuator_dense: int ray: int segmented_sort: int - tendon_velocity: int + solve_LD_sparse_fused: int update_gradient_JTDAJ_dense: int update_gradient_JTDAJ_sparse: int update_gradient_cholesky: int @@ -141,6 +141,10 @@ class ModelWarp(PyTreeNode): eq_ten_adr: np.ndarray eq_wld_adr: np.ndarray flex_bending: np.ndarray + flex_centered: np.ndarray + flex_conaffinity: np.ndarray + flex_condim: np.ndarray + flex_contype: np.ndarray flex_damping: np.ndarray flex_dim: np.ndarray flex_edge: np.ndarray @@ -149,12 +153,21 @@ class ModelWarp(PyTreeNode): flex_edgenum: np.ndarray flex_elem: np.ndarray flex_elemadr: np.ndarray + flex_elemdataadr: np.ndarray flex_elemedge: np.ndarray flex_elemedgeadr: np.ndarray flex_elemnum: np.ndarray + flex_friction: np.ndarray + flex_margin: np.ndarray + flex_radius: np.ndarray + flex_shell: np.ndarray + flex_shelldataadr: np.ndarray + flex_shellnum: np.ndarray flex_stiffness: np.ndarray + flex_vert: np.ndarray flex_vertadr: np.ndarray flex_vertbodyid: np.ndarray + flex_vertflexid: np.ndarray flex_vertnum: np.ndarray flexedge_J_colind: np.ndarray flexedge_J_rowadr: np.ndarray @@ -173,6 +186,7 @@ class ModelWarp(PyTreeNode): light_targetbodyid: np.ndarray mapM2M: np.ndarray mat_texrepeat: jax.Array + max_ten_J_rownnz: int mesh_polyadr: np.ndarray mesh_polymap: np.ndarray mesh_polymapadr: np.ndarray @@ -191,6 +205,7 @@ class ModelWarp(PyTreeNode): nflexelem: int nflexelemdata: int nflexelemedge: int + nflexshelldata: int nflexvert: int nmaxcondim: int nmaxmeshdeg: int @@ -213,6 +228,8 @@ class ModelWarp(PyTreeNode): oct_coeff: np.ndarray plugin: np.ndarray plugin_attr: np.ndarray + qLD_all_updates: np.ndarray + qLD_level_offsets: np.ndarray qLD_updates: Tuple[np.ndarray, ...] qM_fullm_i: np.ndarray qM_fullm_j: np.ndarray @@ -240,6 +257,9 @@ class ModelWarp(PyTreeNode): sensor_vel_adr: np.ndarray taxel_sensorid: np.ndarray taxel_vertadr: np.ndarray + ten_J_colind: np.ndarray + ten_J_rowadr: np.ndarray + ten_J_rownnz: np.ndarray ten_wrapadr_site: np.ndarray ten_wrapnum_site: np.ndarray tendon_geom_adr: np.ndarray @@ -266,6 +286,7 @@ class DataWarp(PyTreeNode): contact__dim: jax.Array contact__dist: jax.Array contact__efc_address: jax.Array + contact__flex: jax.Array contact__frame: jax.Array contact__friction: jax.Array contact__geom: jax.Array @@ -276,6 +297,7 @@ class DataWarp(PyTreeNode): contact__solref: jax.Array contact__solreffriction: jax.Array contact__type: jax.Array + contact__vert: jax.Array contact__worldid: jax.Array crb: jax.Array efc__D: jax.Array @@ -312,6 +334,7 @@ class DataWarp(PyTreeNode): nf: jax.Array nisland: jax.Array njmax: int + njmax_nnz: int njmax_pad: int nl: jax.Array nworld: int @@ -335,6 +358,7 @@ class DataWarp(PyTreeNode): 'contact__dim', 'contact__dist', 'contact__efc_address', + 'contact__flex', 'contact__frame', 'contact__friction', 'contact__geom', @@ -345,12 +369,14 @@ class DataWarp(PyTreeNode): 'contact__solref', 'contact__solreffriction', 'contact__type', + 'contact__vert', 'contact__worldid', 'naccdmax', 'nacon', 'naconmax', 'ncollision', 'njmax', + 'njmax_nnz', 'njmax_pad', 'nworld', } @@ -398,6 +424,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'contact__dim': 1, 'contact__dist': 1, 'contact__efc_address': 2, + 'contact__flex': 2, 'contact__frame': 3, 'contact__friction': 2, 'contact__geom': 2, @@ -408,6 +435,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'contact__solref': 2, 'contact__solreffriction': 2, 'contact__type': 1, + 'contact__vert': 2, 'contact__worldid': 1, 'crb': 3, 'ctrl': 2, @@ -451,6 +479,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'nf': 1, 'nisland': 1, 'njmax': 0, + 'njmax_nnz': 0, 'njmax_pad': 0, 'nl': 1, 'nworld': 0, @@ -480,7 +509,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'subtree_angmom': 3, 'subtree_com': 3, 'subtree_linvel': 3, - 'ten_J': 3, + 'ten_J': 2, 'ten_length': 2, 'ten_velocity': 2, 'ten_wrapadr': 2, @@ -535,7 +564,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'block_dim__qderiv_actuator_dense': 0, 'block_dim__ray': 0, 'block_dim__segmented_sort': 0, - 'block_dim__tendon_velocity': 0, + 'block_dim__solve_LD_sparse_fused': 0, 'block_dim__update_gradient_JTDAJ_dense': 0, 'block_dim__update_gradient_JTDAJ_sparse': 0, 'block_dim__update_gradient_cholesky': 0, @@ -608,6 +637,10 @@ def _from_elt(cont, axis_size, d, axis_dest): 'eq_wld_adr': 1, 'exclude_signature': 1, 'flex_bending': 2, + 'flex_centered': 1, + 'flex_conaffinity': 1, + 'flex_condim': 1, + 'flex_contype': 1, 'flex_damping': 1, 'flex_dim': 1, 'flex_edge': 2, @@ -616,12 +649,21 @@ def _from_elt(cont, axis_size, d, axis_dest): 'flex_edgenum': 1, 'flex_elem': 1, 'flex_elemadr': 1, + 'flex_elemdataadr': 1, 'flex_elemedge': 1, 'flex_elemedgeadr': 1, 'flex_elemnum': 1, + 'flex_friction': 2, + 'flex_margin': 1, + 'flex_radius': 1, + 'flex_shell': 1, + 'flex_shelldataadr': 1, + 'flex_shellnum': 1, 'flex_stiffness': 2, + 'flex_vert': 2, 'flex_vertadr': 1, 'flex_vertbodyid': 1, + 'flex_vertflexid': 1, 'flex_vertnum': 1, 'flexedge_J_colind': 1, 'flexedge_J_rowadr': 1, @@ -692,6 +734,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'mat_rgba': 3, 'mat_texid': 3, 'mat_texrepeat': 3, + 'max_ten_J_rownnz': 0, 'mesh_face': 2, 'mesh_faceadr': 1, 'mesh_graph': 1, @@ -717,6 +760,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'nC': 0, 'nJfe': 0, 'nJmom': 0, + 'nJten': 0, 'nM': 0, 'na': 0, 'nacttrnbody': 0, @@ -730,6 +774,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'nflexelem': 0, 'nflexelemdata': 0, 'nflexelemedge': 0, + 'nflexshelldata': 0, 'nflexvert': 0, 'ngeom': 0, 'ngravcomp': 0, @@ -811,6 +856,8 @@ def _from_elt(cont, axis_size, d, axis_dest): 'pair_solreffriction': 3, 'plugin': 1, 'plugin_attr': 2, + 'qLD_all_updates': 2, + 'qLD_level_offsets': 1, 'qLD_updates': -1, 'qM_fullm_i': 1, 'qM_fullm_j': 1, @@ -856,6 +903,9 @@ def _from_elt(cont, axis_size, d, axis_dest): 'stat__meaninertia': 1, 'taxel_sensorid': 1, 'taxel_vertadr': 1, + 'ten_J_colind': 1, + 'ten_J_rowadr': 1, + 'ten_J_rownnz': 1, 'ten_wrapadr_site': 1, 'ten_wrapnum_site': 1, 'tendon_actfrclimited': 1, @@ -940,6 +990,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'contact__dim': False, 'contact__dist': False, 'contact__efc_address': False, + 'contact__flex': False, 'contact__frame': False, 'contact__friction': False, 'contact__geom': False, @@ -950,6 +1001,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'contact__solref': False, 'contact__solreffriction': False, 'contact__type': False, + 'contact__vert': False, 'contact__worldid': False, 'crb': True, 'ctrl': True, @@ -993,6 +1045,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'nf': True, 'nisland': True, 'njmax': False, + 'njmax_nnz': False, 'njmax_pad': False, 'nl': True, 'nworld': False, @@ -1077,7 +1130,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'block_dim__qderiv_actuator_dense': False, 'block_dim__ray': False, 'block_dim__segmented_sort': False, - 'block_dim__tendon_velocity': False, + 'block_dim__solve_LD_sparse_fused': False, 'block_dim__update_gradient_JTDAJ_dense': False, 'block_dim__update_gradient_JTDAJ_sparse': False, 'block_dim__update_gradient_cholesky': False, @@ -1150,6 +1203,10 @@ def _from_elt(cont, axis_size, d, axis_dest): 'eq_wld_adr': False, 'exclude_signature': False, 'flex_bending': False, + 'flex_centered': False, + 'flex_conaffinity': False, + 'flex_condim': False, + 'flex_contype': False, 'flex_damping': False, 'flex_dim': False, 'flex_edge': False, @@ -1158,12 +1215,21 @@ def _from_elt(cont, axis_size, d, axis_dest): 'flex_edgenum': False, 'flex_elem': False, 'flex_elemadr': False, + 'flex_elemdataadr': False, 'flex_elemedge': False, 'flex_elemedgeadr': False, 'flex_elemnum': False, + 'flex_friction': False, + 'flex_margin': False, + 'flex_radius': False, + 'flex_shell': False, + 'flex_shelldataadr': False, + 'flex_shellnum': False, 'flex_stiffness': False, + 'flex_vert': False, 'flex_vertadr': False, 'flex_vertbodyid': False, + 'flex_vertflexid': False, 'flex_vertnum': False, 'flexedge_J_colind': False, 'flexedge_J_rowadr': False, @@ -1234,6 +1300,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'mat_rgba': True, 'mat_texid': True, 'mat_texrepeat': True, + 'max_ten_J_rownnz': False, 'mesh_face': False, 'mesh_faceadr': False, 'mesh_graph': False, @@ -1259,6 +1326,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'nC': False, 'nJfe': False, 'nJmom': False, + 'nJten': False, 'nM': False, 'na': False, 'nacttrnbody': False, @@ -1272,6 +1340,7 @@ def _from_elt(cont, axis_size, d, axis_dest): 'nflexelem': False, 'nflexelemdata': False, 'nflexelemedge': False, + 'nflexshelldata': False, 'nflexvert': False, 'ngeom': False, 'ngravcomp': False, @@ -1353,6 +1422,8 @@ def _from_elt(cont, axis_size, d, axis_dest): 'pair_solreffriction': True, 'plugin': False, 'plugin_attr': False, + 'qLD_all_updates': False, + 'qLD_level_offsets': False, 'qLD_updates': False, 'qM_fullm_i': False, 'qM_fullm_j': False, @@ -1398,6 +1469,9 @@ def _from_elt(cont, axis_size, d, axis_dest): 'stat__meaninertia': True, 'taxel_sensorid': False, 'taxel_vertadr': False, + 'ten_J_colind': False, + 'ten_J_rowadr': False, + 'ten_J_rownnz': False, 'ten_wrapadr_site': False, 'ten_wrapnum_site': False, 'tendon_actfrclimited': False, diff --git a/mjx/pyproject.toml b/mjx/pyproject.toml index ff7133cdd6..60063b70f9 100644 --- a/mjx/pyproject.toml +++ b/mjx/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ [project.optional-dependencies] warp = [ - "warp-lang==1.11.1", + "warp-lang==1.12.0", ] [project.scripts] diff --git a/plugin/actuator/register.cc b/plugin/actuator/register.cc index 66f24b9364..e3c3e2f54c 100644 --- a/plugin/actuator/register.cc +++ b/plugin/actuator/register.cc @@ -17,6 +17,6 @@ namespace mujoco::plugin::actuator { -mjPLUGIN_LIB_INIT { Pid::RegisterPlugin(); } +mjPLUGIN_LIB_INIT(actuator) { Pid::RegisterPlugin(); } } // namespace mujoco::plugin::actuator diff --git a/plugin/elasticity/register.cc b/plugin/elasticity/register.cc index ab8d283a60..b4a68754a6 100644 --- a/plugin/elasticity/register.cc +++ b/plugin/elasticity/register.cc @@ -17,7 +17,7 @@ namespace mujoco::plugin::elasticity { -mjPLUGIN_LIB_INIT { +mjPLUGIN_LIB_INIT(elasticity) { Cable::RegisterPlugin(); } diff --git a/plugin/obj_decoder/CMakeLists.txt b/plugin/obj_decoder/CMakeLists.txt index 17b671997c..689dfc51eb 100644 --- a/plugin/obj_decoder/CMakeLists.txt +++ b/plugin/obj_decoder/CMakeLists.txt @@ -12,48 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -if(EMSCRIPTEN) - add_library(obj_decoder OBJECT obj_decoder.cc) -else() - set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) - set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_LIBDIR}") - - add_library(obj_decoder SHARED obj_decoder.cc) -endif() - -target_include_directories(obj_decoder PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/../.. - ${CMAKE_CURRENT_SOURCE_DIR}/../../include +target_compile_definitions(mujoco PRIVATE TINYOBJLOADER_IMPLEMENTATION) +target_sources(mujoco PRIVATE + obj_decoder.cc ) -if(EMSCRIPTEN) - target_link_libraries(obj_decoder PRIVATE - tinyobjloader - ) -else() - target_link_libraries(obj_decoder PRIVATE - mujoco - tinyobjloader - ) -endif() - -target_compile_definitions(obj_decoder PRIVATE TINYOBJLOADER_IMPLEMENTATION) - -target_compile_options(obj_decoder PRIVATE - ${AVX_COMPILE_OPTIONS} - ${MUJOCO_MACOS_COMPILE_OPTIONS} - ${EXTRA_COMPILE_OPTIONS} - ${MUJOCO_CXX_FLAGS} -) - -if(NOT EMSCRIPTEN) - target_link_options(obj_decoder PRIVATE - ${MUJOCO_MACOS_LINK_OPTIONS} - ${EXTRA_LINK_OPTIONS} - ) - - install( - TARGETS obj_decoder - LIBRARY DESTINATION "${CMAKE_INSTALL_BINDIR}/mujoco_plugin" - ) -endif() +target_link_libraries(mujoco PRIVATE tinyobjloader) diff --git a/plugin/obj_decoder/obj_decoder.cc b/plugin/obj_decoder/obj_decoder.cc index 18cfdf5ff7..929e959ac0 100644 --- a/plugin/obj_decoder/obj_decoder.cc +++ b/plugin/obj_decoder/obj_decoder.cc @@ -114,7 +114,7 @@ int CanDecode(const mjResource* resource) { } // namespace -mjPLUGIN_LIB_INIT { +mjPLUGIN_LIB_INIT(obj_decoder) { mjpDecoder decoder; mjp_defaultDecoder(&decoder); decoder.content_type = "model/obj"; diff --git a/plugin/sdf/register.cc b/plugin/sdf/register.cc index 4e727e38e8..591467237d 100644 --- a/plugin/sdf/register.cc +++ b/plugin/sdf/register.cc @@ -20,7 +20,7 @@ namespace mujoco::plugin::sdf { -mjPLUGIN_LIB_INIT { +mjPLUGIN_LIB_INIT(sdf) { Bolt::RegisterPlugin(); Bowl::RegisterPlugin(); Gear::RegisterPlugin(); diff --git a/plugin/sensor/register.cc b/plugin/sensor/register.cc index dd8a70d8d1..b3a587afff 100644 --- a/plugin/sensor/register.cc +++ b/plugin/sensor/register.cc @@ -17,7 +17,7 @@ namespace mujoco::plugin::sensor { -mjPLUGIN_LIB_INIT { +mjPLUGIN_LIB_INIT(sensor) { TouchGrid::RegisterPlugin(); } diff --git a/plugin/stl_decoder/CMakeLists.txt b/plugin/stl_decoder/CMakeLists.txt index 1d383632b0..50eef18d98 100644 --- a/plugin/stl_decoder/CMakeLists.txt +++ b/plugin/stl_decoder/CMakeLists.txt @@ -12,43 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -if(EMSCRIPTEN) - add_library(stl_decoder OBJECT stl_decoder.cc) -else() - set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) - set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_LIBDIR}") - - add_library(stl_decoder SHARED stl_decoder.cc) -endif() - -target_include_directories(stl_decoder PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/../.. - ${CMAKE_CURRENT_SOURCE_DIR}/../../include +target_sources(mujoco PRIVATE + stl_decoder.cc ) - -if(EMSCRIPTEN) - target_link_libraries(stl_decoder PRIVATE) -else() - target_link_libraries(stl_decoder PRIVATE - mujoco - ) -endif() - -target_compile_options(stl_decoder PRIVATE - ${AVX_COMPILE_OPTIONS} - ${MUJOCO_MACOS_COMPILE_OPTIONS} - ${EXTRA_COMPILE_OPTIONS} - ${MUJOCO_CXX_FLAGS} -) - -if(NOT EMSCRIPTEN) - target_link_options(stl_decoder PRIVATE - ${MUJOCO_MACOS_LINK_OPTIONS} - ${EXTRA_LINK_OPTIONS} - ) - - install( - TARGETS stl_decoder - LIBRARY DESTINATION "${CMAKE_INSTALL_BINDIR}/mujoco_plugin" - ) -endif() diff --git a/plugin/stl_decoder/stl_decoder.cc b/plugin/stl_decoder/stl_decoder.cc index c6932d9ab5..7e1d1a5a69 100644 --- a/plugin/stl_decoder/stl_decoder.cc +++ b/plugin/stl_decoder/stl_decoder.cc @@ -132,7 +132,7 @@ int CanDecode(const mjResource* resource) { } // namespace -mjPLUGIN_LIB_INIT { +mjPLUGIN_LIB_INIT(stl_decoder) { mjpDecoder decoder; mjp_defaultDecoder(&decoder); decoder.content_type = "model/stl"; diff --git a/plugin/usd_decoder/usd_decoder.cc b/plugin/usd_decoder/usd_decoder.cc index 58a24590e4..f1ab1c0517 100644 --- a/plugin/usd_decoder/usd_decoder.cc +++ b/plugin/usd_decoder/usd_decoder.cc @@ -2459,7 +2459,7 @@ int CanDecode(const mjResource* resource) { } // namespace // clang-format off -mjPLUGIN_LIB_INIT { +mjPLUGIN_LIB_INIT(usd_decoder) { mjpDecoder decoder; mjp_defaultDecoder(&decoder); decoder.content_type = "model/usd"; diff --git a/python/build_requirements.txt b/python/build_requirements.txt index 49b3750f82..9be02aabac 100644 --- a/python/build_requirements.txt +++ b/python/build_requirements.txt @@ -44,8 +44,9 @@ numpy==2.1.3; python_version >= '3.10' \ --hash=sha256:825656d0743699c529c5943554d223c021ff0494ff1442152ce887ef4f7561a1 \ --hash=sha256:b47fbb433d3260adcd51eb54f92a2ffbc90a4595f8970ee00e064c644ac788f5 \ --hash=sha256:c894b4305373b9c5576d7a12b473702afdf48ce5369c074ba304cc5ad8730dff -pip==24.3.1 \ - --hash=sha256:3790624780082365f47549d032f3770eeb2b1e8bd1f7b2e02dace1afa361b4ed +pip==26.0 \ + --hash=sha256:3ce220a0a17915972fbf1ab451baae1521c4539e778b28127efa79b974aff0fa \ + --hash=sha256:98436feffb9e31bc9339cf369fd55d3331b1580b6a6f1173bacacddcf9c34754 PyOpenGL==3.1.7 \ --hash=sha256:a6ab19cf290df6101aaf7470843a9c46207789855746399d0af92521a0a92b7a pytest==8.3.3 \ diff --git a/python/make_sdist_requirements.txt b/python/make_sdist_requirements.txt index da33bbb9ca..cd3c177453 100644 --- a/python/make_sdist_requirements.txt +++ b/python/make_sdist_requirements.txt @@ -2,8 +2,9 @@ absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 build==1.2.2.post1 \ --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 -pip==24.3.1 \ - --hash=sha256:3790624780082365f47549d032f3770eeb2b1e8bd1f7b2e02dace1afa361b4ed +pip==26.0 \ + --hash=sha256:3ce220a0a17915972fbf1ab451baae1521c4539e778b28127efa79b974aff0fa \ + --hash=sha256:98436feffb9e31bc9339cf369fd55d3331b1580b6a6f1173bacacddcf9c34754 setuptools==78.1.1 \ --hash=sha256:c3a9c4211ff4c309edb8b8c4f1cbfa7ae324c4ba9f91ff254e3d305b9fd54561 \ --hash=sha256:fcc17fd9cd898242f6b4adfaca46137a9edef687f43e6f78469692a5e70d851d diff --git a/python/mujoco/introspect/enums.py b/python/mujoco/introspect/enums.py index d4dcdf34d9..bc28f78d6d 100644 --- a/python/mujoco/introspect/enums.py +++ b/python/mujoco/introspect/enums.py @@ -263,7 +263,8 @@ ('mjDYN_FILTER', 2), ('mjDYN_FILTEREXACT', 3), ('mjDYN_MUSCLE', 4), - ('mjDYN_USER', 5), + ('mjDYN_DCMOTOR', 5), + ('mjDYN_USER', 6), ]), )), ('mjtGain', @@ -274,7 +275,8 @@ ('mjGAIN_FIXED', 0), ('mjGAIN_AFFINE', 1), ('mjGAIN_MUSCLE', 2), - ('mjGAIN_USER', 3), + ('mjGAIN_DCMOTOR', 3), + ('mjGAIN_USER', 4), ]), )), ('mjtBias', @@ -285,7 +287,8 @@ ('mjBIAS_NONE', 0), ('mjBIAS_AFFINE', 1), ('mjBIAS_MUSCLE', 2), - ('mjBIAS_USER', 3), + ('mjBIAS_DCMOTOR', 3), + ('mjBIAS_USER', 4), ]), )), ('mjtObj', diff --git a/python/mujoco/introspect/functions.py b/python/mujoco/introspect/functions.py index ff1bcf125d..0ab6df20f4 100644 --- a/python/mujoco/introspect/functions.py +++ b/python/mujoco/introspect/functions.py @@ -10786,6 +10786,86 @@ ), doc='Set actuator to active adhesion; return error if any.', )), + ('mjs_setToDCMotor', + FunctionDecl( + name='mjs_setToDCMotor', + return_type=PointerType( + inner_type=ValueType(name='char', is_const=True), + ), + parameters=( + FunctionParameterDecl( + name='actuator', + type=PointerType( + inner_type=ValueType(name='mjsActuator'), + ), + ), + FunctionParameterDecl( + name='motorconst', + type=ArrayType( + inner_type=ValueType(name='double'), + extents=(2,), + ), + ), + FunctionParameterDecl( + name='resistance', + type=ValueType(name='double'), + ), + FunctionParameterDecl( + name='nominal', + type=ArrayType( + inner_type=ValueType(name='double'), + extents=(3,), + ), + ), + FunctionParameterDecl( + name='saturation', + type=ArrayType( + inner_type=ValueType(name='double'), + extents=(4,), + ), + ), + FunctionParameterDecl( + name='inductance', + type=ArrayType( + inner_type=ValueType(name='double'), + extents=(2,), + ), + ), + FunctionParameterDecl( + name='cogging', + type=ArrayType( + inner_type=ValueType(name='double'), + extents=(3,), + ), + ), + FunctionParameterDecl( + name='controller', + type=ArrayType( + inner_type=ValueType(name='double'), + extents=(5,), + ), + ), + FunctionParameterDecl( + name='thermal', + type=ArrayType( + inner_type=ValueType(name='double'), + extents=(6,), + ), + ), + FunctionParameterDecl( + name='lugre', + type=ArrayType( + inner_type=ValueType(name='double'), + extents=(6,), + ), + ), + FunctionParameterDecl( + name='input_mode', + type=ValueType(name='int'), + ), + ), + doc='Set actuator to DC motor; return error if any.', + )), ('mjs_addMesh', FunctionDecl( name='mjs_addMesh', diff --git a/python/mujoco/specs.cc b/python/mujoco/specs.cc index e3bfc7ef9a..28af8af17f 100644 --- a/python/mujoco/specs.cc +++ b/python/mujoco/specs.cc @@ -1303,6 +1303,31 @@ PYBIND11_MODULE(_specs, m) { } }, py::arg("gain")); + mjsActuator.def( + "set_to_dcmotor", + [](raw::MjsActuator* self, std::array motorconst, + double resistance, + std::array nominal, std::array saturation, + std::array inductance, std::array cogging, + std::array controller, std::array thermal, + std::array lugre, int input_mode) { + std::string err = mjs_setToDCMotor( + self, motorconst.data(), resistance, nominal.data(), + saturation.data(), inductance.data(), cogging.data(), + controller.data(), thermal.data(), lugre.data(), input_mode); + if (!err.empty()) { + throw pybind11::value_error(err); + } + }, + py::arg("motorconst"), py::arg("resistance"), + py::arg("nominal") = std::array{0, 0, 0}, + py::arg("saturation") = std::array{0, 0, 0, 0}, + py::arg("inductance") = std::array{0, 0}, + py::arg("cogging") = std::array{0, 0, 0}, + py::arg("controller") = std::array{0, 0, 0, 0, 0}, + py::arg("thermal") = std::array{0, 0, 0, 0, 0, 0}, + py::arg("lugre") = std::array{0, 0, 0, 0, 0, 0}, + py::arg("input_mode") = 0); // ============================= MJSTENDONPATH =============================== // helper struct for tendon path indexing diff --git a/python/mujoco/specs_test.py b/python/mujoco/specs_test.py index f1c99d823c..a82646dde4 100644 --- a/python/mujoco/specs_test.py +++ b/python/mujoco/specs_test.py @@ -1557,6 +1557,13 @@ def test_actuator_shortname(self): self.assertEqual(actuator.gaintype, mujoco.mjtGain.mjGAIN_FIXED) self.assertEqual(actuator.biastype, mujoco.mjtBias.mjBIAS_NONE) + actuator.set_to_dcmotor(motorconst=[0.05, 0.05], resistance=2.0) + self.assertEqual(actuator.gainprm[0], 2.0) + self.assertEqual(actuator.gainprm[1], 0.05) + self.assertEqual(actuator.dyntype, mujoco.mjtDyn.mjDYN_DCMOTOR) + self.assertEqual(actuator.gaintype, mujoco.mjtGain.mjGAIN_DCMOTOR) + self.assertEqual(actuator.biastype, mujoco.mjtBias.mjBIAS_DCMOTOR) + def test_bad_contact_sensor(self): test_cases = [ dict( diff --git a/sample/testspeed.cc b/sample/testspeed.cc index 544edfc613..f14e657212 100644 --- a/sample/testspeed.cc +++ b/sample/testspeed.cc @@ -187,12 +187,6 @@ int main(int argc, char** argv) { nthread = mjMAX(1, mjMIN(maxthread, nthread)); npoolthread = mjMAX(1, mjMIN(maxthread, npoolthread)); - // load plugins from MUJOCO_PLUGIN_DIR if set - const char* plugin_dir = std::getenv("MUJOCO_PLUGIN_DIR"); - if (plugin_dir) { - mj_loadAllPluginLibraries(plugin_dir, nullptr); - } - // get filename, determine file type std::string filename(argv[1]); bool binary = (filename.find(".mjb") != std::string::npos); // NOLINT diff --git a/src/engine/engine_derivative.c b/src/engine/engine_derivative.c index 3f9b467a56..00267a3e9e 100644 --- a/src/engine/engine_derivative.c +++ b/src/engine/engine_derivative.c @@ -1107,6 +1107,17 @@ void mjd_actuator_vel(const mjModel* m, mjData* d) { bias_vel = (m->actuator_biasprm + mjNBIAS*i)[2]; } + // DC motor bias (back-EMF) + else if (m->actuator_biastype[i] == mjBIAS_DCMOTOR) { + const mjtNum* dynprm = m->actuator_dynprm + mjNDYN*i; + const mjtNum* gainprm = m->actuator_gainprm + mjNGAIN*i; + if (dynprm[0] <= 0) { + mjtNum R = mju_max(mjMINVAL, gainprm[0]); + mjtNum K = gainprm[1]; + bias_vel -= K * K / R; + } + } + // affine gain if (m->actuator_gaintype[i] == mjGAIN_AFFINE) { // extract bias info: prm = [const, kp, kv] @@ -1122,6 +1133,28 @@ void mjd_actuator_vel(const mjModel* m, mjData* d) { m->actuator_gainprm + mjNGAIN*i); } + // DC motor controller damping and LuGre micro-damping + else if (m->actuator_gaintype[i] == mjGAIN_DCMOTOR) { + const mjtNum* dynprm = m->actuator_dynprm + mjNDYN*i; + const mjtNum* gainprm = m->actuator_gainprm + mjNGAIN*i; + int input_mode = (int)gainprm[8]; + if (input_mode > 0) { + mjtNum R = gainprm[0]; + mjtNum K = gainprm[1]; + mjtNum gain = (dynprm[0] > 0) ? K : K / mju_max(mjMINVAL, R); + mjtNum kp = gainprm[4]; + mjtNum kd = gainprm[6]; + bias_vel -= gain * (input_mode == 1 ? kd : kp); + } + + // LuGre: force includes -sigma1*z_dot, z_dot = a*z + v + // d(sigma1*z_dot)/dv = sigma1*(da/dv*z + 1), ignoring higher-order da/dv*z + mjtNum sigma1 = dynprm[6]; + if (sigma1 > 0) { + bias_vel -= sigma1; + } + } + // force = gain .* [ctrl/act] if (gain_vel != 0) { if (m->actuator_dyntype[i] == mjDYN_NONE) { diff --git a/src/engine/engine_forward.c b/src/engine/engine_forward.c index b28c342feb..e8e08f009b 100644 --- a/src/engine/engine_forward.c +++ b/src/engine/engine_forward.c @@ -257,6 +257,36 @@ void mj_fwdVelocity(const mjModel* m, mjData* d) { } +// helper for DC motor: computes control voltage from PID state +static mjtNum dcmotorVoltage(mjtNum ctrl, mjtNum length, mjtNum velocity, + mjtNum x_I, const mjtNum* gainprm) { + int input_mode = (int)gainprm[8]; + mjtNum Vmax = gainprm[7]; + mjtNum voltage; + + // get voltage + if (input_mode > 0) { + mjtNum kp = gainprm[4]; // proportional gain + mjtNum ki = gainprm[5]; // integral gain + mjtNum kd = gainprm[6]; // derivative gain + + if (input_mode == 1) { + // position mode + voltage = kp * (ctrl - length) + ki * x_I - kd * velocity; + } else { + // velocity mode + voltage = kp * (ctrl - velocity) + ki * (x_I - length); + } + } else { + voltage = ctrl; + } + + // clip voltage + if (Vmax > 0) voltage = mju_clip(voltage, -Vmax, Vmax); + + return voltage; +} + // clamp vector to range static void clampVec(mjtNum* vec, const mjtNum* range, const mjtByte* limited, int n, @@ -275,7 +305,7 @@ void mj_fwdActuation(const mjModel* m, mjData* d) { TM_START; int nv = m->nv, nu = m->nu, ntendon = m->ntendon; mjtNum gain, bias, tau; - mjtNum *prm, *force = d->actuator_force; + mjtNum *force = d->actuator_force; // clear actuator_force mju_zero(force, nu); @@ -327,37 +357,136 @@ void mj_fwdActuation(const mjModel* m, mjData* d) { } // zero act_dot for actuator plugins - if (m->actuator_actnum[i]) { - mju_zero(d->act_dot + act_first, m->actuator_actnum[i]); + int actnum = m->actuator_actnum[i]; + if (actnum) { + mju_zero(d->act_dot + act_first, actnum); } // extract info - prm = m->actuator_dynprm + i*mjNDYN; + const mjtNum* dynprm = m->actuator_dynprm + i*mjNDYN; + mjtDyn dyntype = m->actuator_dyntype[i]; // index into the last element in act. For most actuators it's also the - // first element, but actuator plugins might store their own state in act. - int act_last = act_first + m->actuator_actnum[i] - 1; + // first element, but actuator plugins might store their own state in act + int act_last = act_first + actnum - 1; // compute act_dot according to dynamics type - switch ((mjtDyn) m->actuator_dyntype[i]) { + switch (dyntype) { case mjDYN_INTEGRATOR: // simple integrator d->act_dot[act_last] = ctrl[i]; break; - case mjDYN_FILTER: // linear filter: prm = tau + case mjDYN_FILTER: // linear filter: dynprm = tau case mjDYN_FILTEREXACT: - tau = mju_max(mjMINVAL, prm[0]); + tau = mju_max(mjMINVAL, dynprm[0]); d->act_dot[act_last] = (ctrl[i] - d->act[act_last]) / tau; break; - case mjDYN_MUSCLE: // muscle model: prm = (tau_act, tau_deact) - d->act_dot[act_last] = mju_muscleDynamics( - ctrl[i], d->act[act_last], prm); + case mjDYN_MUSCLE: // muscle model: dynprm = (tau_act, tau_deact) + d->act_dot[act_last] = mju_muscleDynamics(ctrl[i], d->act[act_last], dynprm); break; + case mjDYN_DCMOTOR: { // DC motor: up to 5 optional states + const mjtNum* gainprm = m->actuator_gainprm + mjNGAIN*i; + + // verify allocated state size matches parameters; SHOULD NOT OCCUR + if (mj_dcmotorSlots(dynprm, gainprm).num_slots != actnum) { + mjERROR("inconsistent state array dimension in DC motor (actuator %d)", i); + } + + int adr = act_first; + mjtNum velocity = d->actuator_velocity[i]; + mjtNum R = gainprm[0]; // resistance + mjtNum K = gainprm[1]; // motor constant + mjtNum ki = gainprm[5]; // integral gain + mjtNum te = dynprm[0]; // electrical time constant + + // slot order: slew, integral, temperature, bristle, current + + // controller state: slew rate limiting + mjtNum slew_s = dynprm[7]; // slew rate limit + if (slew_s > 0) { + mjtNum u_prev = d->act[adr]; + mjtNum slew = slew_s * m->opt.timestep; + mjtNum u_eff = mju_clip(ctrl[i], u_prev - slew, u_prev + slew); + d->act_dot[adr] = (u_eff - u_prev) / m->opt.timestep; + ctrl[i] = u_eff; + adr++; + } + + // controller state: integral state + mjtNum x_I = 0; + if (ki > 0) { + x_I = d->act[adr]; + int input_mode = (int)gainprm[8]; + mjtNum Imax = dynprm[8]; // integral clamp + mjtNum act_dot = ctrl[i]; // default raw accumulator for voltage and velocity modes + + // position mode + if (input_mode == 1) { + act_dot = ctrl[i] - d->actuator_length[i]; + } + + // clamp act_dot based on integral state + if (Imax > 0) { + if (x_I >= Imax) { + act_dot = mju_min(act_dot, 0); + } else if (x_I <= -Imax) { + act_dot = mju_max(act_dot, 0); + } + } + d->act_dot[adr] = act_dot; + adr++; + } + + // compute physical voltage to feed into current and temperature equations + mjtNum V = dcmotorVoltage(ctrl[i], d->actuator_length[i], velocity, x_I, gainprm); + + // temperature: dT/dt = (R*i^2 - T/RT) / C, where T = delta above ambient + mjtNum RT = dynprm[2]; // thermal resistance + if (RT > 0) { + mjtNum C = dynprm[3]; // thermal capacitance + mjtNum Ta = dynprm[4]; // ambient temperature + mjtNum alpha = gainprm[2]; // temperature coefficient + mjtNum T0 = gainprm[3]; // reference temperature + mjtNum T = d->act[adr]; // temperature rise above ambient + R *= 1 + alpha * (T + Ta - T0); + + // get current: from act_last if stateful, from (V - K*omega)/R if stateless + mjtNum current = (te > 0) ? d->act[act_last] : (V - K * velocity) / R; + d->act_dot[adr] = (R*current*current - T / RT) / C; + adr++; + } + + // LuGre bristle state: dz/dt = v - sigma0 * |v| / g(v) * z + mjtNum sigma0 = dynprm[5]; // bristle stiffness + if (sigma0 > 0) { + const mjtNum* biasprm = m->actuator_biasprm + mjNBIAS*i; + mjtNum F_C = biasprm[3]; // Coulomb friction + mjtNum F_S = biasprm[4]; // static friction + mjtNum v_S = biasprm[5]; // Stribeck velocity + mjtNum z = d->act[adr]; // bristle state + mjtNum g = mj_lugreStribeck(velocity, F_C, F_S, v_S); + mjtNum a = -sigma0 * mju_abs(velocity) / mju_max(mjMINVAL, g); + d->act_dot[adr] = a * z + velocity; + adr++; + } + + // current state: di/dt = (V/R - K/R*omega - i) / te + if (te > 0) { + mjtNum dimax = dynprm[1]; // current rate limit (di/dt)_max + mjtNum i_dot = (V/R - K/R*velocity - d->act[act_last]) / te; + if (dimax > 0) { + i_dot = mju_clip(i_dot, -dimax, dimax); + } + d->act_dot[act_last] = i_dot; + } + break; + } + default: // user dynamics if (mjcb_act_dyn) { - if (m->actuator_actnum[i] == 1) { + if (actnum == 1) { // scalar activation dynamics, get act_dot d->act_dot[act_last] = mjcb_act_dyn(m, d, i); } else { @@ -407,17 +536,20 @@ void mj_fwdActuation(const mjModel* m, mjData* d) { tendon_frclimited = m->tendon_actfrclimited[m->actuator_trnid[2*i]]; } - // extract gain info - prm = m->actuator_gainprm + mjNGAIN*i; + // extract info + const mjtNum* dynprm = m->actuator_dynprm + mjNDYN*i; + const mjtNum* gainprm = m->actuator_gainprm + mjNGAIN*i; + mjtGain gaintype = m->actuator_gaintype[i]; + int actnum = m->actuator_actnum[i]; // handle according to gain type - switch ((mjtGain) m->actuator_gaintype[i]) { + switch (gaintype) { case mjGAIN_FIXED: // fixed gain: prm = gain - gain = prm[0]; + gain = gainprm[0]; break; case mjGAIN_AFFINE: // affine: prm = [const, kp, kv] - gain = prm[0] + prm[1]*d->actuator_length[i] + prm[2]*d->actuator_velocity[i]; + gain = gainprm[0] + gainprm[1]*d->actuator_length[i] + gainprm[2]*d->actuator_velocity[i]; break; case mjGAIN_MUSCLE: // muscle gain @@ -425,9 +557,43 @@ void mj_fwdActuation(const mjModel* m, mjData* d) { d->actuator_velocity[i], m->actuator_lengthrange+2*i, m->actuator_acc0[i], - prm); + gainprm); break; + case mjGAIN_DCMOTOR: { // DC motor: gain = K or K/R + mjtNum R = gainprm[0]; // resistance + mjtNum K = gainprm[1]; // motor constant + mjDCMotorSlots slots = mj_dcmotorSlots(dynprm, gainprm); + + // verify allocated state size matches parameters; SHOULD NOT OCCUR + if (slots.num_slots != actnum) { + mjERROR("inconsistent state array dimension in DC motor (actuator %d)", i); + } + + int adr = m->actuator_actadr[i]; + + // adjust R for temperature if enabled + if (slots.temperature >= 0) { + mjtNum T = d->act[adr + slots.temperature]; + mjtNum alpha = gainprm[2]; // temperature coefficient + mjtNum T0 = gainprm[3]; // reference temperature + mjtNum Ta = dynprm[4]; // ambient temperature + R *= 1 + alpha * (T + Ta - T0); + } + + // stateful current: gain = K, force = K * act[last] (generic path) + // stateless: gain = K/R, force = K/R * ctrl (condition below) + gain = (dynprm[0] > 0) ? K : K / mju_max(mjMINVAL, R); + + // controller: compute voltage, override ctrl[i] for force computation + if ((int)gainprm[8] > 0) { + mjtNum x_I = (slots.integral >= 0) ? d->act[adr + slots.integral] : 0; + ctrl[i] = dcmotorVoltage(ctrl[i], d->actuator_length[i], + d->actuator_velocity[i], x_I, gainprm); + } + break; + } + default: // user gain if (mjcb_act_gain) { gain = mjcb_act_gain(m, d, i); @@ -437,11 +603,14 @@ void mj_fwdActuation(const mjModel* m, mjData* d) { } // set force = gain .* [ctrl/act] - if (m->actuator_actadr[i] == -1) { + + // DC motor without current state: use ctrl even if other activations exist + int dcmotor_no_current = (gaintype == mjGAIN_DCMOTOR && dynprm[0] <= 0); + if (actnum == 0 || dcmotor_no_current) { force[i] = gain * ctrl[i]; } else { // use last activation variable associated with actuator i - int act_adr = m->actuator_actadr[i] + m->actuator_actnum[i] - 1; + int act_adr = m->actuator_actadr[i] + actnum - 1; mjtNum act; if (m->actuator_actearly[i]) { @@ -453,25 +622,38 @@ void mj_fwdActuation(const mjModel* m, mjData* d) { } // extract bias info - prm = m->actuator_biasprm + mjNBIAS*i; + const mjtNum* biasprm = m->actuator_biasprm + mjNBIAS*i; + mjtBias biastype = m->actuator_biastype[i]; // handle according to bias type - switch ((mjtBias) m->actuator_biastype[i]) { + switch (biastype) { case mjBIAS_NONE: // none bias = 0.0; break; - case mjBIAS_AFFINE: // affine: prm = [const, kp, kv] - bias = prm[0] + prm[1]*d->actuator_length[i] + prm[2]*d->actuator_velocity[i]; + case mjBIAS_AFFINE: // affine: biasprm = [const, kp, kv] + bias = biasprm[0] + biasprm[1]*d->actuator_length[i] + biasprm[2]*d->actuator_velocity[i]; break; case mjBIAS_MUSCLE: // muscle passive force bias = mju_muscleBias(d->actuator_length[i], m->actuator_lengthrange+2*i, m->actuator_acc0[i], - prm); + biasprm); break; + case mjBIAS_DCMOTOR: { // DC motor: back-EMF only (current-limited) + bias = 0; + + // back-EMF (stateless only; for stateful current it's in the ODE) + mjtNum te = m->actuator_dynprm[mjNDYN*i]; // electrical time constant + if (te <= 0) { + mjtNum K = gainprm[1]; // motor constant + bias -= gain * K * d->actuator_velocity[i]; + } + break; + } + default: // user bias if (mjcb_act_bias) { bias = mjcb_act_bias(m, d, i); @@ -537,6 +719,41 @@ void mj_fwdActuation(const mjModel* m, mjData* d) { // clamp actuator_force clampVec(force, m->actuator_forcerange, m->actuator_forcelimited, nu, NULL); + // add DC motor mechanical forces (not subject to current limits) + for (int i=0; i < nu; i++) { + if (m->actuator_biastype[i] != mjBIAS_DCMOTOR) { + continue; + } + if (sleep_filter && mj_sleepState(m, d, mjOBJ_ACTUATOR, i) == mjS_ASLEEP) { + continue; + } + if (mj_actuatorDisabled(m, i) || m->actuator_plugin[i] >= 0) { + continue; + } + + const mjtNum* biasprm = m->actuator_biasprm + mjNBIAS*i; + const mjtNum* dynprm = m->actuator_dynprm + mjNDYN*i; + + // cogging torque + mjtNum A = biasprm[0]; + if (A != 0) { + mjtNum Np = biasprm[1]; + mjtNum phi = biasprm[2]; + force[i] += A * mju_sin(Np*d->actuator_length[i] + phi); + } + + // LuGre friction + mjtNum sigma0 = dynprm[5]; + if (sigma0 > 0) { + mjtNum sigma1 = dynprm[6]; + mjDCMotorSlots slots = mj_dcmotorSlots(dynprm, m->actuator_gainprm + mjNGAIN*i); + int adr = m->actuator_actadr[i] + slots.bristle; + mjtNum z = d->act[adr]; + mjtNum z_dot = d->act_dot[adr]; + force[i] -= sigma0 * z + sigma1 * z_dot; + } + } + // qfrc_actuator = moment' * force mju_mulMatTVecSparse(d->qfrc_actuator, d->actuator_moment, force, nu, nv, d->moment_rownnz, d->moment_rowadr, d->moment_colind); diff --git a/src/engine/engine_support.c b/src/engine/engine_support.c index a4827ddb7c..81c3b6caa9 100644 --- a/src/engine/engine_support.c +++ b/src/engine/engine_support.c @@ -709,22 +709,69 @@ int mj_actuatorDisabled(const mjModel* m, int i) { mjtNum mj_nextActivation(const mjModel* m, const mjData* d, int actuator_id, int act_adr, mjtNum act_dot) { mjtNum act = d->act[act_adr]; + int dyntype = m->actuator_dyntype[actuator_id]; - if (m->actuator_dyntype[actuator_id] == mjDYN_FILTEREXACT) { + if (dyntype == mjDYN_FILTEREXACT) { // exact filter integration // act_dot(0) = (ctrl-act(0)) / tau // act(h) = act(0) + (ctrl-act(0)) (1 - exp(-h / tau)) // = act(0) + act_dot(0) * tau * (1 - exp(-h / tau)) mjtNum tau = mju_max(mjMINVAL, m->actuator_dynprm[actuator_id*mjNDYN]); act = act + act_dot * tau * (1 - mju_exp(-m->opt.timestep / tau)); - } else { - // Euler integration + } else if (dyntype == mjDYN_DCMOTOR) { + const mjtNum* dynprm = m->actuator_dynprm + actuator_id * mjNDYN; + const mjtNum* gainprm = m->actuator_gainprm + actuator_id * mjNGAIN; + mjDCMotorSlots slots = mj_dcmotorSlots(dynprm, gainprm); + + int offset = act_adr - m->actuator_actadr[actuator_id]; + + // current filter: exact integration + if (offset == slots.current) { + mjtNum te = mju_max(mjMINVAL, dynprm[0]); + act = act + act_dot * te * (1 - mju_exp(-m->opt.timestep / te)); + } + + // LuGre bristle: dz/dt = a*z + v where a = -sigma0*|v|/g(v) + else if (offset == slots.bristle) { + const mjtNum* biasprm = m->actuator_biasprm + mjNBIAS*actuator_id; + mjtNum F_C = biasprm[3]; // Coulomb friction + mjtNum F_S = biasprm[4]; // static friction + mjtNum v_S = biasprm[5]; // Stribeck velocity + mjtNum sigma0 = dynprm[5]; // bristle stiffness + mjtNum velocity = d->actuator_velocity[actuator_id]; + mjtNum g = mj_lugreStribeck(velocity, F_C, F_S, v_S); + + // ZOH exact ZOH integration: z(h) = exp(ah)*z(0) + ((exp(ah)-1)/a)*v + mjtNum a = -sigma0 * mju_abs(velocity) / mju_max(mjMINVAL, g); // decay rate + mjtNum h = m->opt.timestep; + mjtNum exp_ah = mju_exp(a * h); // state transition + mjtNum int_h = mju_abs(a) > mjMINVAL ? (exp_ah - 1) / a : h; // input integral + act = exp_ah * act + int_h * velocity; + } + + // integral state: Euler integration with anti-windup clamp + else if (offset == slots.integral) { + act = act + act_dot * m->opt.timestep; + mjtNum Imax = dynprm[8]; + if (Imax > 0) { + act = mju_clip(act, -Imax, Imax); + } + } + + // temperature and slew: Euler integration + else { + act = act + act_dot * m->opt.timestep; + } + } + + // otherwise Euler integration + else { act = act + act_dot * m->opt.timestep; } - // clamp to actrange - if (m->actuator_actlimited[actuator_id]) { - mjtNum* actrange = m->actuator_actrange + 2*actuator_id; + // clamp to actrange unless DC motor + if (dyntype != mjDYN_DCMOTOR && m->actuator_actlimited[actuator_id]) { + const mjtNum* actrange = m->actuator_actrange + 2*actuator_id; act = mju_clip(act, actrange[0], actrange[1]); } diff --git a/src/engine/engine_util_misc.c b/src/engine/engine_util_misc.c index 0dd2ca9753..65057a51e8 100644 --- a/src/engine/engine_util_misc.c +++ b/src/engine/engine_util_misc.c @@ -769,6 +769,26 @@ mjtNum mju_muscleDynamics(mjtNum ctrl, mjtNum act, const mjtNum prm[3]) { } +// LuGre Stribeck function: g(v) = F_C + (F_S - F_C) * exp(-(v/v_S)^2) +mjtNum mj_lugreStribeck(mjtNum velocity, mjtNum F_C, mjtNum F_S, mjtNum v_S) { + mjtNum ratio = velocity / mju_max(mjMINVAL, v_S); + return F_C + (F_S - F_C) * mju_exp(-ratio*ratio); +} + + +// compute DC motor activation slot indices from parameter arrays +mjDCMotorSlots mj_dcmotorSlots(const mjtNum* dynprm, const mjtNum* gainprm) { + mjDCMotorSlots s = {-1, -1, -1, -1, -1, 0}; + if (dynprm[7] > 0) s.slew = s.num_slots++; // slew rate limiting + if (gainprm[5] > 0) s.integral = s.num_slots++; // PI integral + if (dynprm[2] > 0) s.temperature = s.num_slots++; // thermal model + if (dynprm[5] > 0) s.bristle = s.num_slots++; // LuGre bristle + if (dynprm[0] > 0) s.current = s.num_slots++; // current filter + + return s; +} + + //---------------------------------------- Base64 -------------------------------------------------- // decoding function for Base64 diff --git a/src/engine/engine_util_misc.h b/src/engine/engine_util_misc.h index 756b3b8ebf..cac5bed59c 100644 --- a/src/engine/engine_util_misc.h +++ b/src/engine/engine_util_misc.h @@ -50,6 +50,23 @@ MJAPI mjtNum mju_muscleDynamicsTimescale(mjtNum dctrl, mjtNum tau_act, mjtNum ta // muscle activation dynamics, prm = (tau_act, tau_deact, smoothing_width) MJAPI mjtNum mju_muscleDynamics(mjtNum ctrl, mjtNum act, const mjtNum prm[3]); +// LuGre Stribeck function: g(v) = F_C + (F_S - F_C) * exp(-(v/v_S)^2) +mjtNum mj_lugreStribeck(mjtNum velocity, mjtNum F_C, mjtNum F_S, mjtNum v_S); + +// DC motor activation slot indices (-1 = slot not active) +typedef struct { + int slew; // slew rate state + int integral; // integral state + int temperature; // temperature state + int bristle; // LuGre bristle state + int current; // current state + int num_slots; // number of DC motor states +} mjDCMotorSlots; + +// compute activation slot indices for a DC motor actuator +// dynprm = actuator_dynprm row, gainprm = actuator_gainprm row +mjDCMotorSlots mj_dcmotorSlots(const mjtNum* dynprm, const mjtNum* gainprm); + // all 3 semi-axes of a geom MJAPI void mju_geomSemiAxes(mjtNum semiaxes[3], const mjtNum size[3], mjtGeom type); diff --git a/src/experimental/filament/CMakeLists.txt b/src/experimental/filament/CMakeLists.txt index 5c9242aa43..457a5db9d6 100644 --- a/src/experimental/filament/CMakeLists.txt +++ b/src/experimental/filament/CMakeLists.txt @@ -47,6 +47,8 @@ target_sources(${MUJOCO_FILAMENT_TARGET_NAME} filament/material.h filament/math_util.cc filament/math_util.h + filament/model_objects.cc + filament/model_objects.h filament/model_util.cc filament/model_util.h filament/object_manager.cc @@ -57,8 +59,8 @@ target_sources(${MUJOCO_FILAMENT_TARGET_NAME} filament/renderables.h filament/scene_view.cc filament/scene_view.h - filament/texture_util.cc - filament/texture_util.h + filament/texture.cc + filament/texture.h filament/vertex_util.cc filament/vertex_util.h ) diff --git a/src/experimental/filament/filament/buffer_util.h b/src/experimental/filament/filament/buffer_util.h index 4a2a11046d..b6167ff837 100644 --- a/src/experimental/filament/filament/buffer_util.h +++ b/src/experimental/filament/filament/buffer_util.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -33,7 +34,7 @@ namespace mujoco { struct FilamentBuffers { filament::IndexBuffer* index_buffer = nullptr; filament::VertexBuffer* vertex_buffer = nullptr; - filament::Box bounds = {{-1, -1, -1}, {1, 1, 1}}; + std::optional bounds = std::nullopt; filament::RenderableManager::PrimitiveType type = filament::RenderableManager::PrimitiveType::TRIANGLES; }; diff --git a/src/experimental/filament/filament/builtins.cc b/src/experimental/filament/filament/builtins.cc index ba2525a8cc..825aa95371 100644 --- a/src/experimental/filament/filament/builtins.cc +++ b/src/experimental/filament/filament/builtins.cc @@ -87,7 +87,7 @@ class LineBuilder { } filament::Box GetBounds() const { - return {{-0.001, -0.001, 0}, {0.001, 0.001, 1}}; + return filament::Box().set({-0.001, -0.001, 0}, {0.001, 0.001, 1}); } }; @@ -137,7 +137,9 @@ class PlaneBuilder { } } - filament::Box GetBounds() const { return {{-1, -1, -0.001}, {1, 1, 0.001}}; } + filament::Box GetBounds() const { + return filament::Box().set({-1, -1, -0.001}, {1, 1, 0.001}); + } private: int num_quads_per_axis_; @@ -174,7 +176,9 @@ class TriangleBuilder { ptr[2] = 2; } - filament::Box GetBounds() const { return {{-1, -1, -0.001}, {1, 1, 0.001}}; } + filament::Box GetBounds() const { + return filament::Box().set({-1, -1, -0.001}, {1, 1, 0.001}); + } private: float4 orientation_; @@ -239,7 +243,9 @@ class LineBoxBuilder { ptr[23] = 5; } - filament::Box GetBounds() const { return {{-1, -1, -1}, {1, 1, 1}}; } + filament::Box GetBounds() const { + return filament::Box().set({-1, -1, -1}, {1, 1, 1}); + } }; class BoxBuilder { @@ -305,7 +311,9 @@ class BoxBuilder { } } - filament::Box GetBounds() const { return {{-1, -1, -1}, {1, 1, 1}}; } + filament::Box GetBounds() const { + return filament::Box().set({-1, -1, -1}, {1, 1, 1}); + } private: template @@ -378,7 +386,9 @@ class TubeBuilder { } } - filament::Box GetBounds() const { return {{-1, -1, -1}, {1, 1, 1}}; } + filament::Box GetBounds() const { + return filament::Box().set({-1, -1, -1}, {1, 1, 1}); + } private: int num_stacks_; @@ -461,7 +471,9 @@ class ConeBuilder { } } - filament::Box GetBounds() const { return {{-1, -1, 0}, {1, 1, 1}}; } + filament::Box GetBounds() const { + return filament::Box().set({-1, -1, 0}, {1, 1, 1}); + } private: static VertexType MakeVert(float theta, float radius) { @@ -520,7 +532,9 @@ class DiskBuilder { } } - filament::Box GetBounds() const { return {{-1, -1, -0.001}, {1, 1, 0.001}}; } + filament::Box GetBounds() const { + return filament::Box().set({-1, -1, -0.001}, {1, 1, 0.001}); + } private: int num_slices_; @@ -625,7 +639,9 @@ class SphereBuilder { } } - filament::Box GetBounds() const { return {{-1, -1, -1}, {1, 1, 1}}; } + filament::Box GetBounds() const { + return filament::Box().set({-1, -1, -1}, {1, 1, 1}); + } private: static VertexType MakeVert(float x, float y, float z) { @@ -726,7 +742,9 @@ class DomeBuilder { } } - filament::Box GetBounds() const { return {{-1, -1, 0}, {1, 1, 1}}; } + filament::Box GetBounds() const { + return filament::Box().set({-1, -1, 0}, {1, 1, 1}); + } private: static VertexType MakeVert(float x, float y, float z) { diff --git a/src/experimental/filament/filament/drawable.cc b/src/experimental/filament/filament/drawable.cc index 8dd575bba6..f32afe1ae4 100644 --- a/src/experimental/filament/filament/drawable.cc +++ b/src/experimental/filament/filament/drawable.cc @@ -35,7 +35,9 @@ #include "experimental/filament/filament/geom_util.h" #include "experimental/filament/filament/material.h" #include "experimental/filament/filament/math_util.h" +#include "experimental/filament/filament/model_objects.h" #include "experimental/filament/filament/object_manager.h" +#include "experimental/filament/filament/texture.h" namespace mujoco { @@ -85,8 +87,11 @@ static bool IsBehind(const mjtNum* headpos, const float* pos, const float* mat) (headpos[2] - pos[2]) * mat[8] < 0.0f); } -Drawable::Drawable(ObjectManager* object_mgr, const mjvGeom& geom) - : material_(object_mgr), renderables_(object_mgr->GetEngine()) { +Drawable::Drawable(ObjectManager* object_mgr, ModelObjects* model_objects, + const mjvGeom& geom) + : material_(object_mgr), + model_objs_(model_objects), + renderables_(object_mgr->GetEngine()) { if (geom.category == mjCAT_DECOR) { renderables_.SetCastShadows(false); renderables_.SetReceiveShadows(false); @@ -100,53 +105,53 @@ Drawable::Drawable(ObjectManager* object_mgr, const mjvGeom& geom) AddHeightField(geom.dataid); break; case mjGEOM_PLANE: - AddShape(ObjectManager::kPlane); + AddShape(ModelObjects::kPlane); break; case mjGEOM_SPHERE: - AddShape(ObjectManager::kSphere); + AddShape(ModelObjects::kSphere); break; case mjGEOM_ELLIPSOID: - AddShape(ObjectManager::kSphere); + AddShape(ModelObjects::kSphere); break; case mjGEOM_BOX: - AddShape(ObjectManager::kBox); + AddShape(ModelObjects::kBox); break; case mjGEOM_CAPSULE: - AddShape(ObjectManager::kTube); - AddShape(ObjectManager::kDome); - AddShape(ObjectManager::kDome); + AddShape(ModelObjects::kTube); + AddShape(ModelObjects::kDome); + AddShape(ModelObjects::kDome); break; case mjGEOM_CYLINDER: - AddShape(ObjectManager::kTube); - AddShape(ObjectManager::kDisk); - AddShape(ObjectManager::kDisk); + AddShape(ModelObjects::kTube); + AddShape(ModelObjects::kDisk); + AddShape(ModelObjects::kDisk); break; case mjGEOM_ARROW: - AddShape(ObjectManager::kTube); - AddShape(ObjectManager::kCone); - AddShape(ObjectManager::kDisk); + AddShape(ModelObjects::kTube); + AddShape(ModelObjects::kCone); + AddShape(ModelObjects::kDisk); break; case mjGEOM_ARROW1: - AddShape(ObjectManager::kTube); - AddShape(ObjectManager::kCone); - AddShape(ObjectManager::kDisk); - AddShape(ObjectManager::kDisk); + AddShape(ModelObjects::kTube); + AddShape(ModelObjects::kCone); + AddShape(ModelObjects::kDisk); + AddShape(ModelObjects::kDisk); break; case mjGEOM_ARROW2: - AddShape(ObjectManager::kTube); - AddShape(ObjectManager::kCone); - AddShape(ObjectManager::kCone); - AddShape(ObjectManager::kDisk); - AddShape(ObjectManager::kDisk); + AddShape(ModelObjects::kTube); + AddShape(ModelObjects::kCone); + AddShape(ModelObjects::kCone); + AddShape(ModelObjects::kDisk); + AddShape(ModelObjects::kDisk); break; case mjGEOM_LINE: - AddShape(ObjectManager::kLine); + AddShape(ModelObjects::kLine); break; case mjGEOM_LINEBOX: - AddShape(ObjectManager::kLineBox); + AddShape(ModelObjects::kLineBox); break; case mjGEOM_TRIANGLE: - AddShape(ObjectManager::kTriangle); + AddShape(ModelObjects::kTriangle); break; case mjGEOM_FLEX: case mjGEOM_SKIN: @@ -187,8 +192,7 @@ void Drawable::Update(const mjModel* model, const mjvScene* scene, } void Drawable::AddMesh(int data_id) { - ObjectManager* object_mgr = material_.GetObjectManager(); - const FilamentBuffers* buffers = object_mgr->GetMeshBuffer(data_id); + const FilamentBuffers* buffers = model_objs_->GetMeshBuffer(data_id); if (buffers == nullptr) { mju_error("Unknown mesh %d", data_id); } @@ -196,17 +200,15 @@ void Drawable::AddMesh(int data_id) { } void Drawable::AddHeightField(int hfield_id) { - ObjectManager* object_mgr = material_.GetObjectManager(); - const FilamentBuffers* buffers = object_mgr->GetHeightFieldBuffer(hfield_id); + const FilamentBuffers* buffers = model_objs_->GetHeightFieldBuffer(hfield_id); if (buffers == nullptr) { mju_error("Unknown height field %d", hfield_id); } renderables_.Append(*buffers); } -void Drawable::AddShape(ObjectManager::ShapeType shape_type) { - ObjectManager* object_mgr = material_.GetObjectManager(); - const FilamentBuffers* buffers = object_mgr->GetShapeBuffer(shape_type); +void Drawable::AddShape(ModelObjects::ShapeType shape_type) { + const FilamentBuffers* buffers = model_objs_->GetShapeBuffer(shape_type); if (buffers == nullptr) { mju_error("Unknown shape %d", shape_type); } @@ -225,7 +227,7 @@ void Drawable::SetDrawMode(Material::DrawMode mode) { renderables_.SetMaterialInstance(material_.GetMaterialInstance(mode)); } -void Drawable::UpdateReflectionTexture(const filament::Texture* tex) { +void Drawable::UpdateReflectionTexture(const Texture* tex) { material_.UpdateReflectionTexture(tex); } @@ -352,8 +354,7 @@ void Drawable::SetTransform(const mjvGeom& geom) { void Drawable::UpdateMaterial(const mjvGeom& geom, bool use_segid_color, bool enable_reflection, const mjtNum* headpos) { - ObjectManager* object_mgr = material_.GetObjectManager(); - const mjModel* model = object_mgr->GetModel(); + const mjModel* model = model_objs_->GetModel(); float4 color = ReadFloat4(geom.rgba); if (geom.type == mjGEOM_PLANE) { @@ -370,15 +371,15 @@ void Drawable::UpdateMaterial(const mjvGeom& geom, bool use_segid_color, Material::Textures textures; if (geom.matid >= 0) { - textures.color = object_mgr->GetTexture(geom.matid, mjTEXROLE_RGB); - textures.normal = object_mgr->GetTexture(geom.matid, mjTEXROLE_NORMAL); - textures.emissive = object_mgr->GetTexture(geom.matid, mjTEXROLE_EMISSIVE); - textures.orm = object_mgr->GetTexture(geom.matid, mjTEXROLE_ORM); - textures.metallic = object_mgr->GetTexture(geom.matid, mjTEXROLE_METALLIC); + textures.color = model_objs_->GetTexture(geom.matid, mjTEXROLE_RGB); + textures.normal = model_objs_->GetTexture(geom.matid, mjTEXROLE_NORMAL); + textures.emissive = model_objs_->GetTexture(geom.matid, mjTEXROLE_EMISSIVE); + textures.orm = model_objs_->GetTexture(geom.matid, mjTEXROLE_ORM); + textures.metallic = model_objs_->GetTexture(geom.matid, mjTEXROLE_METALLIC); textures.roughness = - object_mgr->GetTexture(geom.matid, mjTEXROLE_ROUGHNESS); + model_objs_->GetTexture(geom.matid, mjTEXROLE_ROUGHNESS); textures.occlusion = - object_mgr->GetTexture(geom.matid, mjTEXROLE_OCCLUSION); + model_objs_->GetTexture(geom.matid, mjTEXROLE_OCCLUSION); material_.UpdateTextures(textures); } @@ -422,7 +423,7 @@ void Drawable::UpdateMaterial(const mjvGeom& geom, bool use_segid_color, } else { material_.SetNormalMaterialType(ObjectManager::kPhongColor); } - } else if (textures.color->getTarget() == + } else if (textures.color->GetFilamentTexture()->getTarget() == filament::Texture::Sampler::SAMPLER_CUBEMAP) { if (color.a < 1.0f) { material_.SetNormalMaterialType(ObjectManager::kPhongCubeFade); @@ -490,7 +491,8 @@ void Drawable::UpdateMaterial(const mjvGeom& geom, bool use_segid_color, // the programmatic UVs. if (textures.color) { - if (textures.color->getTarget() == filament::Texture::Sampler::SAMPLER_2D) { + if (textures.color->GetFilamentTexture()->getTarget() == + filament::Texture::Sampler::SAMPLER_2D) { // For 2D textures, `tex_repeat` specifies how many times the texture // image is repeated. The `tex_uniform` flag determines if the repetition // is applied at in object space (false) or in world space (true). @@ -548,9 +550,9 @@ void Drawable::UpdateMaterial(const mjvGeom& geom, bool use_segid_color, } // Apply material multipliers from the model. - params.emissive *= object_mgr->GetEmissiveMultiplier(); - params.specular *= object_mgr->GetSpecularMultiplier(); - params.glossiness *= object_mgr->GetShininessMultiplier(); + params.emissive *= model_objs_->GetEmissiveMultiplier(); + params.specular *= model_objs_->GetSpecularMultiplier(); + params.glossiness *= model_objs_->GetShininessMultiplier(); material_.UpdateParams(params); } diff --git a/src/experimental/filament/filament/drawable.h b/src/experimental/filament/filament/drawable.h index dc1f95d040..2f2dc81e38 100644 --- a/src/experimental/filament/filament/drawable.h +++ b/src/experimental/filament/filament/drawable.h @@ -23,15 +23,18 @@ #include #include #include "experimental/filament/filament/material.h" +#include "experimental/filament/filament/model_objects.h" #include "experimental/filament/filament/object_manager.h" #include "experimental/filament/filament/renderables.h" +#include "experimental/filament/filament/texture.h" namespace mujoco { // Manages the filament Entities and MaterialInstances for a single mjvGeom. class Drawable { public: - Drawable(ObjectManager* object_mgr, const mjvGeom& geom); + Drawable(ObjectManager* object_mgr, ModelObjects* model_objects, + const mjvGeom& geom); ~Drawable() noexcept = default; Drawable(const Drawable&) = delete; @@ -66,12 +69,12 @@ class Drawable { // Sets the reflection texture for the drawable. We have a separate setter // because we need to render the reflection texture before it can be applied // to the material. - void UpdateReflectionTexture(const filament::Texture* tex); + void UpdateReflectionTexture(const Texture* tex); private: void AddMesh(int data_id); void AddHeightField(int hfield_id); - void AddShape(ObjectManager::ShapeType shape_type); + void AddShape(ModelObjects::ShapeType shape_type); // Updates the transform of the drawable for rendering. void SetTransform(const mjvGeom& geom); @@ -81,6 +84,7 @@ class Drawable { bool enable_reflection, const mjtNum* headpos); Material material_; + ModelObjects* model_objs_ = nullptr; Renderables renderables_; bool reflective_ = false; filament::math::mat4 transform_; diff --git a/src/experimental/filament/filament/filament_context.cc b/src/experimental/filament/filament/filament_context.cc index 0fa4b14d5d..cbd06da21e 100644 --- a/src/experimental/filament/filament/filament_context.cc +++ b/src/experimental/filament/filament/filament_context.cc @@ -41,10 +41,11 @@ #include "experimental/filament/filament/filament_platform_factory.h" #include "experimental/filament/filament/gui_view.h" #include "experimental/filament/filament/imgui_editor.h" -#include "experimental/filament/filament/object_manager.h" #include "experimental/filament/filament/model_util.h" -#include "experimental/filament/filament/scene_view.h" +#include "experimental/filament/filament/object_manager.h" #include "experimental/filament/filament/render_target_util.h" +#include "experimental/filament/filament/scene_view.h" +#include "experimental/filament/filament/texture.h" #include "experimental/filament/render_context_filament.h" namespace mujoco { @@ -62,16 +63,20 @@ FilamentContext::FilamentContext(const mjrFilamentConfig* config) engine_ = engine_builder.build(); renderer_ = engine_->createRenderer(); - #ifdef __EMSCRIPTEN__ - window_swap_chain_ = engine_->createSwapChain(nullptr); - #else +#ifdef __EMSCRIPTEN__ + window_swap_chain_ = engine_->createSwapChain(nullptr); +#else if (config_.native_window) { window_swap_chain_ = engine_->createSwapChain(config_.native_window); } else { - window_swap_chain_ = engine_->createSwapChain(config_.width, config_.height); + window_swap_chain_ = + engine_->createSwapChain(config_.width, config_.height); } - #endif - offscreen_swap_chain_ = engine_->createSwapChain(config_.width, config_.height); +#endif + offscreen_swap_chain_ = + engine_->createSwapChain(config_.width, config_.height); + + object_manager_ = std::make_unique(engine_); } FilamentContext::~FilamentContext() { @@ -86,8 +91,7 @@ FilamentContext::~FilamentContext() { } void FilamentContext::Init(const mjModel* model) { - object_manager_ = std::make_unique(model, engine_); - scene_view_ = std::make_unique(engine_, object_manager_.get()); + scene_view_ = std::make_unique(object_manager_.get(), model); gui_view_ = std::make_unique( engine_, object_manager_->GetMaterial(ObjectManager::kUnlitUi)); @@ -182,11 +186,13 @@ void FilamentContext::SetFrameBuffer(int framebuffer) { void FilamentContext::PrepareRenderTargets(int width, int height) { color_target_ = std::make_unique( - engine_, kRenderTargetColor, kRenderTargetDepth); + engine_, RenderTargetTextureType::kColor, + RenderTargetTextureType::kDepth); color_target_->Prepare(width, height); depth_target_ = std::make_unique( - engine_, kRenderTargetDepthColor, kRenderTargetDepth); + engine_, RenderTargetTextureType::kDepthColor, + RenderTargetTextureType::kDepth); depth_target_->Prepare(width, height); } @@ -270,15 +276,24 @@ void FilamentContext::ReadPixels(mjrRect viewport, unsigned char* rgb, } void FilamentContext::UploadMesh(const mjModel* model, int id) { - object_manager_->UploadMesh(model, id); + if (!scene_view_) { + mju_error("SceneView is not initialized."); + } + scene_view_->UploadMesh(model, id); } void FilamentContext::UploadTexture(const mjModel* model, int id) { - object_manager_->UploadTexture(model, id); + if (!scene_view_) { + mju_error("SceneView is not initialized."); + } + scene_view_->UploadTexture(model, id); } void FilamentContext::UploadHeightField(const mjModel* model, int id) { - object_manager_->UploadHeightField(model, id); + if (!scene_view_) { + mju_error("SceneView is not initialized."); + } + scene_view_->UploadHeightField(model, id); } uintptr_t FilamentContext::UploadGuiImage(uintptr_t tex_id, @@ -300,8 +315,6 @@ double FilamentContext::GetFrameRate() const { return 1.0e9 / static_cast(ns); } -void FilamentContext::UpdateGui() { - DrawGui(scene_view_.get()); -} +void FilamentContext::UpdateGui() { DrawGui(scene_view_.get()); } } // namespace mujoco diff --git a/src/experimental/filament/filament/filament_context.h b/src/experimental/filament/filament/filament_context.h index ec0a724347..7e2e2a9aba 100644 --- a/src/experimental/filament/filament/filament_context.h +++ b/src/experimental/filament/filament/filament_context.h @@ -26,8 +26,8 @@ #include #include "experimental/filament/filament/gui_view.h" #include "experimental/filament/filament/object_manager.h" -#include "experimental/filament/filament/scene_view.h" #include "experimental/filament/filament/render_target_util.h" +#include "experimental/filament/filament/scene_view.h" #include "experimental/filament/render_context_filament.h" namespace mujoco { diff --git a/src/experimental/filament/filament/geom_util.cc b/src/experimental/filament/filament/geom_util.cc index fed02258f8..a26fd07317 100644 --- a/src/experimental/filament/filament/geom_util.cc +++ b/src/experimental/filament/filament/geom_util.cc @@ -14,12 +14,16 @@ #include "experimental/filament/filament/geom_util.h" +#include #include #include #include #include #include +#include +#include +#include #include #include "experimental/filament/filament/buffer_util.h" #include "experimental/filament/filament/math_util.h" @@ -27,6 +31,8 @@ namespace mujoco { +using filament::math::float3; + static std::span GetPositions(const mjModel* model, const mjvScene* scene, const mjvGeom& geom) { @@ -102,11 +108,15 @@ template static void FillVertices(std::byte* buffer, std::size_t len, std::span positions, std::span normals, - std::span uvs) { + std::span uvs, + float3* vmin, + float3* vmax) { const int num_vertices = len / sizeof(T); T* ptr = reinterpret_cast(buffer); for (int i = 0; i < num_vertices; ++i) { ptr->position = ReadFloat3(positions.data(), i); + *vmin = min(*vmin, ptr->position); + *vmax = max(*vmax, ptr->position); ptr->orientation = CalculateOrientation(ReadFloat3(normals.data(), i)); if constexpr (T::kHasUv) { ptr->uv.x = uvs[i * 2]; @@ -118,18 +128,21 @@ static void FillVertices(std::byte* buffer, std::size_t len, static filament::VertexBuffer* BuildVertexBuffer( filament::Engine* engine, std::span positions, - std::span normals, std::span uvs) { + std::span normals, std::span uvs, float3* vmin, + float3* vmax) { const int num_vertices = positions.size() / 3; if (uvs.data() != nullptr) { using VertexType = VertexWithUv; auto fill = [&](std::byte* buffer, std::size_t len) { - FillVertices(buffer, len, positions, normals, uvs); + FillVertices(buffer, len, positions, normals, uvs, vmin, + vmax); }; return CreateVertexBuffer(engine, num_vertices, fill); } else { using VertexType = VertexNoUv; auto fill = [&](std::byte* buffer, std::size_t len) { - FillVertices(buffer, len, positions, normals, uvs); + FillVertices(buffer, len, positions, normals, uvs, vmin, + vmax); }; return CreateVertexBuffer(engine, num_vertices, fill); } @@ -163,8 +176,12 @@ FilamentBuffers CreateGeomBuffers(filament::Engine* engine, } FilamentBuffers buffers; - buffers.vertex_buffer = BuildVertexBuffer(engine, positions, normals, uvs); + float3 vmin = {FLT_MAX, FLT_MAX, FLT_MAX}; + float3 vmax = {-FLT_MAX, -FLT_MAX, -FLT_MAX}; + buffers.vertex_buffer = + BuildVertexBuffer(engine, positions, normals, uvs, &vmin, &vmax); buffers.index_buffer = BuildIndexBuffer(engine, indices, num_indices); + buffers.bounds.emplace().set(vmin, vmax); return buffers; } diff --git a/src/experimental/filament/filament/gui_view.cc b/src/experimental/filament/filament/gui_view.cc index 8fa83cae55..1876936eb1 100644 --- a/src/experimental/filament/filament/gui_view.cc +++ b/src/experimental/filament/filament/gui_view.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -31,6 +32,7 @@ #include #include #include "experimental/filament/filament/buffer_util.h" +#include "experimental/filament/filament/texture.h" #include "experimental/filament/filament/vertex_util.h" namespace mujoco { @@ -67,9 +69,7 @@ GuiView::~GuiView() { for (auto& instance : instances_) { engine_->destroy(instance); } - for (auto& texture : textures_) { - engine_->destroy(texture.second); - } + textures_.clear(); engine_->destroyCameraComponent(camera_->getEntity()); engine_->destroy(view_); engine_->destroy(scene_); @@ -98,62 +98,52 @@ uintptr_t GuiView::UploadImage(uintptr_t tex_id, const uint8_t* pixels, mju_error("Unsupported image bpp. Got %d, wanted 3 or 4", bpp); } - const auto internal_format = - bpp == 4 ? filament::Texture::InternalFormat::RGBA8 - : filament::Texture::InternalFormat::RGB8; - const auto texture_format = bpp == 4 ? filament::Texture::Format::RGBA - : filament::Texture::Format::RGB; + if (pixels == nullptr) { + // If the pixels are nullptr, we destroy the texture. + if (tex_id != 0) { + textures_.erase(tex_id); + } + return 0; + } - filament::Texture* texture = nullptr; + // Assign a new texture ID. if (tex_id == 0) { - texture = filament::Texture::Builder() - .width(width) - .height(height) - .levels(1) - .format(internal_format) - .sampler(filament::Texture::Sampler::SAMPLER_2D) - .build(*engine_); tex_id = textures_.size() + 1; - textures_[tex_id] = texture; - } else { - auto iter = textures_.find(tex_id); - if (iter == textures_.end()) { - mju_error("Texture not found: %lu", tex_id); - } - texture = iter->second; + } - if (pixels == nullptr) { - // A nullptr implies that the user wants to destroy the texture. - engine_->destroy(texture); - textures_.erase(tex_id); - return 0; - } else if (texture->getWidth() != width || texture->getHeight() != height) { - // Recreate the texture if the dimensions have changed. - engine_->destroy(texture); - texture = filament::Texture::Builder() - .width(width) - .height(height) - .levels(1) - .format(internal_format) - .sampler(filament::Texture::Sampler::SAMPLER_2D) - .build(*engine_); - textures_[tex_id] = texture; - } + std::unique_ptr& texture = textures_[tex_id]; + + // If the texture does not exist or the dimensions have changed, we create a + // new texture. + if (texture == nullptr || texture->GetWidth() != width || + texture->GetHeight() != height) { + TextureConfig config; + DefaultTextureConfig(&config); + config.width = width; + config.height = height; + config.target = mjTEXTURE_2D; + config.format = bpp == 4 ? mjPIXEL_FORMAT_RGBA8 : mjPIXEL_FORMAT_RGB8; + config.color_space = mjCOLORSPACE_LINEAR; + texture = std::make_unique(engine_, config); } // Create a copy of the image to pass it to filament as we don't know the // lifetime of the data. - const int num_bytes = width * height * bpp; + const size_t num_bytes = width * height * bpp; std::byte* bytes = new std::byte[num_bytes]; - std::memcpy(bytes, pixels, num_bytes); - const auto callback = [](void* buffer, size_t size, void* user) { - auto* ptr = reinterpret_cast(user); - delete[] ptr; + const auto callback = +[](void* user) { + delete[] reinterpret_cast(user); }; - filament::Texture::PixelBufferDescriptor pb(bytes, num_bytes, texture_format, - filament::Texture::Type::UBYTE, - callback); - texture->setImage(*engine_, 0, std::move(pb)); + + TextureData texture_data; + DefaultTextureData(&texture_data); + texture_data.bytes = bytes; + texture_data.nbytes = num_bytes; + texture_data.user_data = bytes; + texture_data.release_callback = callback; + + std::memcpy(bytes, pixels, num_bytes); + texture->Upload(texture_data); return tex_id; } @@ -162,40 +152,39 @@ void GuiView::CreateTexture(ImTextureData* data) { mju_error("Unsupported texture format."); } - filament::Texture* texture = - filament::Texture::Builder() - .width(data->Width) - .height(data->Height) - .levels(1) - .format(filament::Texture::InternalFormat::RGBA8) - .sampler(filament::Texture::Sampler::SAMPLER_2D) - .build(*engine_); + TextureConfig config; + DefaultTextureConfig(&config); + config.width = data->Width; + config.height = data->Height; + config.target = mjTEXTURE_2D; + config.format = mjPIXEL_FORMAT_RGBA8; + config.color_space = mjCOLORSPACE_LINEAR; const uintptr_t tex_id = textures_.size() + 1; - textures_[tex_id] = texture; + textures_[tex_id] = std::make_unique(engine_, config); data->SetTexID((ImTextureID)tex_id); UpdateTexture(data); } void GuiView::UpdateTexture(ImTextureData* data) { - const int size = data->Width * data->Height * 4; - filament::Texture::PixelBufferDescriptor pb(data->GetPixels(), size, - filament::Texture::Format::RGBA, - filament::Texture::Type::UBYTE); auto iter = textures_.find(data->TexID); if (iter == textures_.end()) { mju_error("Texture not found: %llu", data->TexID); } - filament::Texture* texture = iter->second; - texture->setImage(*engine_, 0, std::move(pb)); + TextureData texture_data; + DefaultTextureData(&texture_data); + texture_data.bytes = data->GetPixels(); + texture_data.nbytes = data->Width * data->Height * 4; + texture_data.user_data = nullptr; + texture_data.release_callback = nullptr; + iter->second->Upload(texture_data); data->SetStatus(ImTextureStatus_OK); } void GuiView::DestroyTexture(ImTextureData* data) { auto iter = textures_.find(data->TexID); if (iter != textures_.end()) { - engine_->destroy(iter->second); textures_.erase(data->TexID); data->SetTexID(ImTextureID_Invalid); data->SetStatus(ImTextureStatus_Destroyed); @@ -361,7 +350,8 @@ filament::MaterialInstance* GuiView::GetMaterialInstance(int index, } filament::MaterialInstance* instance = instances_[index]; - instance->setParameter("glyph", iter->second, filament::TextureSampler()); + instance->setParameter("glyph", iter->second->GetFilamentTexture(), + filament::TextureSampler()); instance->setScissor(rect.left, rect.bottom, rect.width, rect.height); return instance; } diff --git a/src/experimental/filament/filament/gui_view.h b/src/experimental/filament/filament/gui_view.h index 2adb6d07af..2b189324f6 100644 --- a/src/experimental/filament/filament/gui_view.h +++ b/src/experimental/filament/filament/gui_view.h @@ -16,6 +16,7 @@ #define MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_GUI_VIEW_H_ #include +#include #include #include @@ -29,6 +30,7 @@ #include #include #include "experimental/filament/filament/buffer_util.h" +#include "experimental/filament/filament/texture.h" namespace mujoco { @@ -71,7 +73,7 @@ class GuiView { utils::Entity renderable_; std::vector buffers_; std::vector instances_; - std::unordered_map textures_; + std::unordered_map> textures_; int num_elements_ = 0; }; diff --git a/src/experimental/filament/filament/material.cc b/src/experimental/filament/filament/material.cc index 1557e9e55c..aaf9c6ba42 100644 --- a/src/experimental/filament/filament/material.cc +++ b/src/experimental/filament/filament/material.cc @@ -21,12 +21,13 @@ #include #include #include "experimental/filament/filament/object_manager.h" +#include "experimental/filament/filament/texture.h" namespace mujoco { Material::Material(ObjectManager* object_mgr) : object_mgr_(object_mgr) { instances_[kDepth] = - object_mgr->GetMaterial(ObjectManager::kUnlitDepth)->createInstance(); + object_mgr_->GetMaterial(ObjectManager::kUnlitDepth)->createInstance(); instances_[kSegmentation] = object_mgr_->GetMaterial(ObjectManager::kUnlitSegmentation) ->createInstance(); @@ -71,7 +72,7 @@ void Material::UpdateTextures(const Textures& textures) { UpdateMaterialInstances(); } -void Material::UpdateReflectionTexture(const filament::Texture* tex) { +void Material::UpdateReflectionTexture(const Texture* tex) { textures_.reflection = tex; UpdateMaterialInstances(); } @@ -128,70 +129,26 @@ void Material::UpdateMaterialInstances() { sampler.setMinFilter( filament::TextureSampler::MinFilter::LINEAR_MIPMAP_LINEAR); - if (material->hasParameter("BaseColor")) { - if (textures_.color) { - instance->setParameter("BaseColor", textures_.color, sampler); - } else { - auto* fallback = object_mgr_->GetFallbackTexture(mjTEXROLE_RGB); - instance->setParameter("BaseColor", fallback, sampler); + auto TrySetTexture = [&](const char* name, const Texture* texture, + mjtTextureRole role) { + if (material->hasParameter(name)) { + if (texture) { + instance->setParameter(name, texture->GetFilamentTexture(), sampler); + } else { + auto* fallback = object_mgr_->GetFallbackTexture(role); + instance->setParameter(name, fallback->GetFilamentTexture(), sampler); + } } - } - if (material->hasParameter("Normal")) { - if (textures_.normal) { - instance->setParameter("Normal", textures_.normal, sampler); - } else { - auto* fallback = object_mgr_->GetFallbackTexture(mjTEXROLE_NORMAL); - instance->setParameter("Normal", fallback, sampler); - } - } - if (material->hasParameter("Metallic")) { - if (textures_.metallic) { - instance->setParameter("Metallic", textures_.metallic, sampler); - } else { - auto* fallback = object_mgr_->GetFallbackTexture(mjTEXROLE_METALLIC); - instance->setParameter("Metallic", fallback, sampler); - } - } - if (material->hasParameter("Roughness")) { - if (textures_.roughness) { - instance->setParameter("Roughness", textures_.roughness, sampler); - } else { - auto* fallback = object_mgr_->GetFallbackTexture(mjTEXROLE_ROUGHNESS); - instance->setParameter("Roughness", fallback, sampler); - } - } - if (material->hasParameter("Occlusion")) { - if (textures_.occlusion) { - instance->setParameter("Occlusion", textures_.occlusion, sampler); - } else { - auto* fallback = object_mgr_->GetFallbackTexture(mjTEXROLE_OCCLUSION); - instance->setParameter("Occlusion", fallback, sampler); - } - } - if (material->hasParameter("ORM")) { - if (textures_.orm) { - instance->setParameter("ORM", textures_.orm, sampler); - } else { - auto* fallback = object_mgr_->GetFallbackTexture(mjTEXROLE_ORM); - instance->setParameter("ORM", fallback, sampler); - } - } - if (material->hasParameter("Emissive")) { - if (textures_.emissive) { - instance->setParameter("Emissive", textures_.emissive, sampler); - } else { - auto* fallback = object_mgr_->GetFallbackTexture(mjTEXROLE_EMISSIVE); - instance->setParameter("Emissive", fallback, sampler); - } - } - if (material->hasParameter("Reflection")) { - if (textures_.reflection) { - instance->setParameter("Reflection", textures_.reflection, sampler); - } else { - auto* fallback = object_mgr_->GetFallbackTexture(mjTEXROLE_USER); - instance->setParameter("Reflection", fallback, sampler); - } - } + }; + + TrySetTexture("BaseColor", textures_.color, mjTEXROLE_RGB); + TrySetTexture("Normal", textures_.normal, mjTEXROLE_NORMAL); + TrySetTexture("Metallic", textures_.metallic, mjTEXROLE_METALLIC); + TrySetTexture("Roughness", textures_.roughness, mjTEXROLE_ROUGHNESS); + TrySetTexture("Occlusion", textures_.occlusion, mjTEXROLE_OCCLUSION); + TrySetTexture("ORM", textures_.orm, mjTEXROLE_ORM); + TrySetTexture("Emissive", textures_.emissive, mjTEXROLE_EMISSIVE); + TrySetTexture("Reflection", textures_.reflection, mjTEXROLE_USER); } } // namespace mujoco diff --git a/src/experimental/filament/filament/material.h b/src/experimental/filament/filament/material.h index 22452fb611..b717b9c41d 100644 --- a/src/experimental/filament/filament/material.h +++ b/src/experimental/filament/filament/material.h @@ -17,11 +17,11 @@ #include #include -#include #include #include #include #include "experimental/filament/filament/object_manager.h" +#include "experimental/filament/filament/texture.h" namespace mujoco { @@ -39,14 +39,14 @@ class Material { // The textures that can be assigned to the drawable's material. struct Textures { - const filament::Texture* color = nullptr; - const filament::Texture* normal = nullptr; - const filament::Texture* metallic = nullptr; - const filament::Texture* roughness = nullptr; - const filament::Texture* occlusion = nullptr; - const filament::Texture* orm = nullptr; - const filament::Texture* emissive = nullptr; - const filament::Texture* reflection = nullptr; + const Texture* color = nullptr; + const Texture* normal = nullptr; + const Texture* metallic = nullptr; + const Texture* roughness = nullptr; + const Texture* occlusion = nullptr; + const Texture* orm = nullptr; + const Texture* emissive = nullptr; + const Texture* reflection = nullptr; }; // The parameters that can be applied to the drawable's material. @@ -82,17 +82,13 @@ class Material { // Update the reflection texture. We do this separately since the reflection // texture needs to be rendered before it can be applied to the material. - void UpdateReflectionTexture(const filament::Texture* tex); + void UpdateReflectionTexture(const Texture* tex); // Returns the material instance assigned to the draw mode. filament::MaterialInstance* GetMaterialInstance(DrawMode mode) { return instances_[mode]; } - // Returns the ObjectManager owning the Materials which are used to create - // the MaterialInstances. - ObjectManager* GetObjectManager() { return object_mgr_; } - private: // Updates the material instances based on the currently set parameters and // textures. diff --git a/src/experimental/filament/filament/model_objects.cc b/src/experimental/filament/filament/model_objects.cc new file mode 100644 index 0000000000..3df4f2e220 --- /dev/null +++ b/src/experimental/filament/filament/model_objects.cc @@ -0,0 +1,272 @@ +// Copyright 2026 DeepMind Technologies Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "experimental/filament/filament/model_objects.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include "experimental/filament/filament/buffer_util.h" +#include "experimental/filament/filament/builtins.h" +#include "experimental/filament/filament/model_util.h" +#include "experimental/filament/filament/texture.h" + + +namespace mujoco { + +ModelObjects::ModelObjects(const mjModel* model, filament::Engine* engine) + : model_(model), engine_(engine) { + const int nstack = model->vis.quality.numstacks; + const int nslice = model->vis.quality.numslices; + const int nquad = model->vis.quality.numquads; + shapes_[kLine] = CreateLine(engine_); + shapes_[kBox] = CreateBox(engine_, nquad); + shapes_[kLineBox] = CreateLineBox(engine_); + shapes_[kCone] = CreateCone(engine_, nstack, nslice); + shapes_[kDisk] = CreateDisk(engine_, nslice); + shapes_[kDome] = CreateDome(engine_, nstack / 2, nslice); + shapes_[kTube] = CreateTube(engine_, nstack, nslice); + shapes_[kPlane] = CreatePlane(engine_, nquad); + shapes_[kSphere] = CreateSphere(engine_, nstack, nslice); + shapes_[kTriangle] = CreateTriangle(engine_); + + for (int i = 0; i < model_->ntex; ++i) { + UploadTexture(model_, i); + } + for (int i = 0; i < model_->nmesh; ++i) { + UploadMesh(model_, i); + } + for (int i = 0; i < model_->nhfield; ++i) { + UploadHeightField(model_, i); + } + + specular_multiplier_ = ReadElement( + model_, "filament.phong.specular_multiplier", specular_multiplier_); + shininess_multiplier_ = ReadElement( + model_, "filament.phong.shininess_multiplier", shininess_multiplier_); + emissive_multiplier_ = ReadElement( + model_, "filament.phong.emissive_multiplier", emissive_multiplier_); +} + +ModelObjects::~ModelObjects() { + for (auto& iter : skyboxes_) { + engine_->destroy(iter); + } + for (auto& iter : indirect_lights_) { + engine_->destroy(iter); + } + for (auto& iter : meshes_) { + engine_->destroy(iter.second.vertex_buffer); + engine_->destroy(iter.second.index_buffer); + } + for (auto& iter : shapes_) { + engine_->destroy(iter.vertex_buffer); + engine_->destroy(iter.index_buffer); + } + textures_.clear(); +} + +void ModelObjects::UploadMesh(const mjModel* model, int id) { + if (model != model_) { + mju_error("Model mismatch."); + } + if (id < 0 || id >= model->nmesh) { + mju_error("Invalid mesh index %d", id); + } + + if (auto iter = meshes_.find(id); iter != meshes_.end()) { + engine_->destroy(iter->second.vertex_buffer); + engine_->destroy(iter->second.index_buffer); + } + if (auto iter = convex_hulls_.find(id); iter != convex_hulls_.end()) { + engine_->destroy(iter->second.vertex_buffer); + engine_->destroy(iter->second.index_buffer); + } + + FilamentBuffers& buffers = meshes_[id]; + buffers.vertex_buffer = CreateVertexBuffer( + engine_, model, id, MeshType::kNormal, &buffers.bounds.emplace()); + buffers.index_buffer = + CreateIndexBuffer(engine_, model, id, MeshType::kNormal); + + if (model->mesh_graphadr[id] >= 0) { + FilamentBuffers& hull_buffers = convex_hulls_[id]; + hull_buffers.vertex_buffer = + CreateVertexBuffer(engine_, model, id, MeshType::kConvexHull, + &hull_buffers.bounds.emplace()); + hull_buffers.index_buffer = + CreateIndexBuffer(engine_, model, id, MeshType::kConvexHull); + } +} + +void ModelObjects::UploadTexture(const mjModel* model, int id) { + if (model != model_) { + mju_error("Model mismatch."); + } + if (id < 0 || id >= model->ntex) { + mju_error("Invalid texture index: %d", id); + } + + TextureConfig config; + DefaultTextureConfig(&config); + config.width = model->tex_width[id]; + config.height = model->tex_height[id]; + config.target = (mjtTexture)model->tex_type[id]; + config.color_space = (mjtColorSpace)model->tex_colorspace[id]; + switch (model->tex_nchannel[id]) { + case 1: + config.format = mjPIXEL_FORMAT_R8; + break; + case 3: + config.format = mjPIXEL_FORMAT_RGB8; + break; + case 4: + config.format = mjPIXEL_FORMAT_RGBA8; + break; + default: + mju_error("Unsupported texture format: %d", model->tex_nchannel[id]); + break; + } + if (config.height == 1 && model->tex_nchannel[id] == 1) { + config.format = mjPIXEL_FORMAT_KTX; + } + + + TextureData payload; + DefaultTextureData(&payload); + payload.bytes = model->tex_data + model->tex_adr[id]; + payload.nbytes = + model->tex_width[id] * model->tex_height[id] * model->tex_nchannel[id]; + // We assume that the model has the same lifetime as the engine. + payload.user_data = nullptr; + payload.release_callback = nullptr; + + auto texture = std::make_unique(engine_, config); + texture->Upload(payload); + textures_[id] = std::move(texture); +} + +void ModelObjects::UploadHeightField(const mjModel* model, int id) { + if (model != model_) { + mju_error("Model mismatch."); + } + if (id < 0 || id >= model->nhfield) { + mju_error("Invalid height field index %d", id); + } + + if (auto iter = height_fields_.find(id); iter != height_fields_.end()) { + engine_->destroy(iter->second.vertex_buffer); + engine_->destroy(iter->second.index_buffer); + } + + FilamentBuffers& buffers = height_fields_[id]; + buffers.vertex_buffer = CreateVertexBuffer( + engine_, model, id, MeshType::kHeightField, &buffers.bounds.emplace()); + buffers.index_buffer = + CreateIndexBuffer(engine_, model, id, MeshType::kHeightField); +} + +const FilamentBuffers* ModelObjects::GetMeshBuffer(int data_id) const { + // As defined by mjv_updateScene: + // original mesh: mesh_id * 2 + // convex hull: (mesh_id * 2) + 1 + const int mesh_id = data_id / 2; + if (data_id % 2 == 0) { + auto it = meshes_.find(mesh_id); + return it != meshes_.end() ? &it->second : nullptr; + } else { + auto it = convex_hulls_.find(mesh_id); + return it != convex_hulls_.end() ? &it->second : nullptr; + } +} + +const FilamentBuffers* ModelObjects::GetHeightFieldBuffer( + int hfield_id) const { + auto it = height_fields_.find(hfield_id); + return it != height_fields_.end() ? &it->second : nullptr; +} + +const FilamentBuffers* ModelObjects::GetShapeBuffer(ShapeType shape) const { + if (shape < 0 || shape >= kNumShapes) { + mju_error("Invalid shape type: %d", shape); + } + return &shapes_[shape]; +} + +const Texture* ModelObjects::GetTexture(int tex_id) const { + auto it = textures_.find(tex_id); + return it != textures_.end() ? it->second.get() : nullptr; +} + +const Texture* ModelObjects::GetTexture(int mat_id, int role) const { + if (mat_id < 0 || mat_id >= model_->nmat || role < 0 || role >= mjNTEXROLE) { + return nullptr; + } + const int tex_id = model_->mat_texid[mat_id * mjNTEXROLE + role]; + return GetTexture(tex_id); +} + +filament::IndirectLight* ModelObjects::CreateIndirectLight(int tex_id, + float intensity) { + filament::Texture* texture = nullptr; + const Texture::SphericalHarmonics* spherical_harmonics = nullptr; + auto texture_iter = textures_.find(tex_id); + if (texture_iter != textures_.end()) { + texture = texture_iter->second->GetFilamentTexture(); + spherical_harmonics = texture_iter->second->GetSphericalHarmonics(); + } + + filament::IndirectLight::Builder builder; + builder.reflections(texture); + if (spherical_harmonics != nullptr) { + builder.irradiance(3, *spherical_harmonics); + } + builder.intensity(intensity); + // Rotate the light to match mujoco's Z-up convention. + builder.rotation(filament::math::mat3f::rotation( + filament::math::f::PI / 2, filament::math::float3{1, 0, 0})); + filament::IndirectLight* indirect_light = builder.build(*engine_); + indirect_lights_.push_back(indirect_light); + return indirect_light; +} + +filament::Skybox* ModelObjects::CreateSkybox() { + filament::Texture* skybox_texture = nullptr; + for (auto& iter : textures_) { + const int texture_type = model_->tex_type[iter.first]; + if (texture_type == mjTEXTURE_SKYBOX) { + skybox_texture = iter.second->GetFilamentTexture(); + break; + } + } + + if (skybox_texture == nullptr) { + return nullptr; + } + + filament::Skybox::Builder builder; + builder.environment(skybox_texture); + filament::Skybox* skybox = builder.build(*engine_); + skyboxes_.push_back(skybox); + return skybox; +} + +} // namespace mujoco diff --git a/src/experimental/filament/filament/model_objects.h b/src/experimental/filament/filament/model_objects.h new file mode 100644 index 0000000000..7007e81a3e --- /dev/null +++ b/src/experimental/filament/filament/model_objects.h @@ -0,0 +1,99 @@ +// Copyright 2026 DeepMind Technologies Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_MODEL_OBJECTS_H_ +#define MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_MODEL_OBJECTS_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include "experimental/filament/filament/buffer_util.h" +#include "experimental/filament/filament/texture.h" + +namespace mujoco { + +// Creates and owns various filament objects based on the data in a mjrContext. +class ModelObjects { + public: + ModelObjects(const mjModel* model, filament::Engine* engine); + ~ModelObjects(); + + enum ShapeType { + kLine, + kLineBox, + kPlane, + kTriangle, + kBox, + kSphere, + kCone, + kDisk, + kDome, + kTube, + kNumShapes, + }; + + void UploadMesh(const mjModel* model, int id); + + void UploadTexture(const mjModel* model, int id); + + void UploadHeightField(const mjModel* model, int id); + + // Returns the filament engine used by the ModelObjects to create filament + // objects. + filament::Engine* GetEngine() const { return engine_; } + + // Returns the cached instance of a filament object created from the mjModel. + const FilamentBuffers* GetShapeBuffer(ShapeType shape) const; + const FilamentBuffers* GetMeshBuffer(int data_id) const; + const FilamentBuffers* GetHeightFieldBuffer(int hfield_id) const; + const Texture* GetTexture(int tex_id) const; + const Texture* GetTexture(int mat_id, int role) const; + + filament::Skybox* CreateSkybox(); + filament::IndirectLight* CreateIndirectLight(int tex_id, float intensity); + + float GetSpecularMultiplier() const { return specular_multiplier_; } + float GetShininessMultiplier() const { return shininess_multiplier_; } + float GetEmissiveMultiplier() const { return emissive_multiplier_; } + + const mjModel* GetModel() const { return model_; } + + ModelObjects(const ModelObjects&) = delete; + ModelObjects& operator=(const ModelObjects&) = delete; + + private: + const mjModel* model_ = nullptr; + filament::Engine* engine_ = nullptr; + std::vector skyboxes_; + std::vector indirect_lights_; + std::array shapes_; + std::unordered_map meshes_; + std::unordered_map convex_hulls_; + std::unordered_map height_fields_; + std::unordered_map> textures_; + float specular_multiplier_ = 0.2f; + float shininess_multiplier_ = 0.1f; + float emissive_multiplier_ = 0.3f; +}; + +} // namespace mujoco + +#endif // MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_MODEL_OBJECTS_H_ diff --git a/src/experimental/filament/filament/model_util.cc b/src/experimental/filament/filament/model_util.cc index 77fc83a847..2bd581ddff 100644 --- a/src/experimental/filament/filament/model_util.cc +++ b/src/experimental/filament/filament/model_util.cc @@ -15,10 +15,12 @@ #include "experimental/filament/filament/model_util.h" #include +#include #include #include #include +#include #include #include #include @@ -30,7 +32,6 @@ #include #include "experimental/filament/filament/buffer_util.h" #include "experimental/filament/filament/math_util.h" -#include "experimental/filament/filament/texture_util.h" #include "experimental/filament/filament/vertex_util.h" namespace mujoco { @@ -48,11 +49,20 @@ static bool UseFaceNormal(const float3& face_normal, // clang-format on } +static void UpdateBounds(const float3& v, float3* vmin, float3* vmax) { + vmin->x = std::min(vmin->x, v.x); + vmin->y = std::min(vmin->y, v.y); + vmin->z = std::min(vmin->z, v.z); + vmax->x = std::max(vmax->x, v.x); + vmax->y = std::max(vmax->y, v.y); + vmax->z = std::max(vmax->z, v.z); +} + template static void FillConvexHullBuffer(T* ptr, std::size_t num, const mjModel* model, - int meshid) { + int meshid, float3* vmin, float3* vmax) { const int numvert = model->mesh_graph[model->mesh_graphadr[meshid]]; - const int numface = model->mesh_graph[model->mesh_graphadr[meshid]+1]; + const int numface = model->mesh_graph[model->mesh_graphadr[meshid] + 1]; const int vertadr = model->mesh_vertadr[meshid]; const float* vertices = model->mesh_vert + (3 * vertadr); @@ -65,13 +75,18 @@ static void FillConvexHullBuffer(T* ptr, std::size_t num, const mjModel* model, } for (int face = 0; face < numface; ++face) { - int j = model->mesh_graphadr[meshid] + 2 + 3*numvert + 3*numface + 3*face; + int j = + model->mesh_graphadr[meshid] + 2 + 3 * numvert + 3 * numface + 3 * face; const float3 p1 = ReadFloat3(vertices, model->mesh_graph[j + 0]); const float3 p2 = ReadFloat3(vertices, model->mesh_graph[j + 1]); const float3 p3 = ReadFloat3(vertices, model->mesh_graph[j + 2]); const float4 orientation = CalculateOrientation(p1, p2, p3); + UpdateBounds(p1, vmin, vmax); + UpdateBounds(p2, vmin, vmax); + UpdateBounds(p3, vmin, vmax); + ptr->position = p1; ptr->orientation = orientation; if constexpr (T::kHasUv) { @@ -97,7 +112,7 @@ static void FillConvexHullBuffer(T* ptr, std::size_t num, const mjModel* model, template static void FillMeshBuffer(T* ptr, std::size_t num, const mjModel* model, - int meshid) { + int meshid, float3* vmin, float3* vmax) { const int faceadr = model->mesh_faceadr[meshid]; const int facenum = model->mesh_facenum[meshid]; if (num != facenum * 3) { @@ -118,8 +133,12 @@ static void FillMeshBuffer(T* ptr, std::size_t num, const mjModel* model, const float3 p1 = ReadFloat3(vertices, model->mesh_face[face + 0]); const float3 p2 = ReadFloat3(vertices, model->mesh_face[face + 1]); const float3 p3 = ReadFloat3(vertices, model->mesh_face[face + 2]); - const float3 face_normal = CalculateNormal(p1, p2, p3); + UpdateBounds(p1, vmin, vmax); + UpdateBounds(p2, vmin, vmax); + UpdateBounds(p3, vmin, vmax); + + const float3 face_normal = CalculateNormal(p1, p2, p3); const float3 n1 = ReadFloat3(normals, model->mesh_facenormal[face + 0]); const float3 n2 = ReadFloat3(normals, model->mesh_facenormal[face + 1]); const float3 n3 = ReadFloat3(normals, model->mesh_facenormal[face + 2]); @@ -160,7 +179,8 @@ static void FillMeshBuffer(T* ptr, std::size_t num, const mjModel* model, } static void FillHeightFieldBuffer(VertexNoUv* ptr, std::size_t num, - const mjModel* model, int hfieldid) { + const mjModel* model, int hfieldid, + float3* vmin, float3* vmax) { int count = 0; auto append_tri = [&](float3 a, float3 b, float3 c) { float4 orientation = CalculateOrientation(a, b, c); @@ -173,6 +193,10 @@ static void FillHeightFieldBuffer(VertexNoUv* ptr, std::size_t num, ptr[count].position = c; ptr[count].orientation = orientation; ++count; + + UpdateBounds(a, vmin, vmax); + UpdateBounds(b, vmin, vmax); + UpdateBounds(c, vmin, vmax); }; auto append_quad = [&](float3 a, float3 b, float3 c, float3 d) { append_tri(a, b, d); @@ -186,7 +210,7 @@ static void FillHeightFieldBuffer(VertexNoUv* ptr, std::size_t num, const float width = 0.5f * (ncol - 1); float sz[4]; for (int i = 0; i < 4; ++i) { - sz[i] = static_cast(model->hfield_size[4 * hfieldid + i]); + sz[i] = static_cast(model->hfield_size[4 * hfieldid + i]); } auto get_pos = [=](int r, int c) { @@ -254,8 +278,8 @@ static void FillHeightFieldBuffer(VertexNoUv* ptr, std::size_t num, } // Build the right edge. for (int row = 0; row < nrow - 1; ++row) { - const float3 a = get_pos(row + 1, ncol-1); - const float3 b = get_pos(row, ncol-1); + const float3 a = get_pos(row + 1, ncol - 1); + const float3 b = get_pos(row, ncol - 1); const float3 c = {b.x, b.y, -sz[3]}; const float3 d = {a.x, a.y, -sz[3]}; append_quad(a, b, c, d); @@ -270,8 +294,8 @@ static void FillHeightFieldBuffer(VertexNoUv* ptr, std::size_t num, } // Build the back edge. for (int col = 0; col < ncol - 1; ++col) { - const float3 a = get_pos(nrow-1, col + 1); - const float3 b = get_pos(nrow-1, col); + const float3 a = get_pos(nrow - 1, col + 1); + const float3 b = get_pos(nrow - 1, col); const float3 c = {b.x, b.y, -sz[3]}; const float3 d = {a.x, a.y, -sz[3]}; append_quad(a, b, c, d); @@ -286,9 +310,7 @@ static void FillHeightFieldBuffer(VertexNoUv* ptr, std::size_t num, const float x1 = sz[0] * ((col + 1) / base_width - 1.0f); const float y0 = sz[1] * ((row + 0) / base_height - 1.0f); const float y1 = sz[1] * ((row + 1) / base_height - 1.0f); - append_quad({x0, y0, -sz[3]}, - {x0, y1, -sz[3]}, - {x1, y1, -sz[3]}, + append_quad({x0, y0, -sz[3]}, {x0, y1, -sz[3]}, {x1, y1, -sz[3]}, {x1, y0, -sz[3]}); } } @@ -305,7 +327,7 @@ static int CalculateHeightFieldVertexCount(const mjModel* model, int hfieldid) { // we need. But, in general... // We use 4 triangles (i.e. 12 vertices) per quad. - const int surface_count = 12 * (nrow-1) * (ncol-1); + const int surface_count = 12 * (nrow - 1) * (ncol - 1); // We use 1 quad (i.e. 6 vertices) per edge element. We double this because // we have two edges per dimension (e.g. left/right and front/back). const int edge_count = (12 * (nrow - 1)) + (12 * (ncol - 1)); @@ -322,17 +344,23 @@ template static filament::VertexBuffer* CreateVertexBuffer(filament::Engine* engine, const mjModel* model, int id, int vertex_count, - FillFn fill_fn) { - return CreateVertexBuffer( + FillFn fill_fn, + filament::Box* bounds) { + float3 vmin = {FLT_MAX, FLT_MAX, FLT_MAX}; + float3 vmax = {-FLT_MAX, -FLT_MAX, -FLT_MAX}; + filament::VertexBuffer* buffer = CreateVertexBuffer( engine, vertex_count, [&](std::byte* buffer, std::size_t num_bytes) { auto* ptr = reinterpret_cast(buffer); - fill_fn(ptr, num_bytes / sizeof(T), model, id); + fill_fn(ptr, num_bytes / sizeof(T), model, id, &vmin, &vmax); }); + bounds->set(vmin, vmax); + return buffer; } filament::VertexBuffer* CreateVertexBuffer(filament::Engine* engine, const mjModel* model, int id, - MeshType mesh_type) { + MeshType mesh_type, + filament::Box* bounds) { if (id < 0) { mju_error("Invalid mesh index %d", id); return nullptr; @@ -376,11 +404,13 @@ filament::VertexBuffer* CreateVertexBuffer(filament::Engine* engine, switch (mesh_type) { case MeshType::kNormal: return CreateVertexBuffer(engine, model, id, vertex_count, - FillMeshBuffer); + FillMeshBuffer, + bounds); break; case MeshType::kConvexHull: return CreateVertexBuffer(engine, model, id, vertex_count, - FillConvexHullBuffer); + FillConvexHullBuffer, + bounds); break; case MeshType::kHeightField: mju_error("Height fields do not support UV coordinates."); @@ -391,15 +421,17 @@ filament::VertexBuffer* CreateVertexBuffer(filament::Engine* engine, switch (mesh_type) { case MeshType::kNormal: return CreateVertexBuffer(engine, model, id, vertex_count, - FillMeshBuffer); + FillMeshBuffer, + bounds); break; case MeshType::kConvexHull: return CreateVertexBuffer(engine, model, id, vertex_count, - FillConvexHullBuffer); + FillConvexHullBuffer, + bounds); break; case MeshType::kHeightField: return CreateVertexBuffer(engine, model, id, vertex_count, - FillHeightFieldBuffer); + FillHeightFieldBuffer, bounds); break; } } @@ -453,23 +485,4 @@ filament::IndexBuffer* CreateIndexBuffer(filament::Engine* engine, FillSequence); } } - -filament::Texture* CreateTexture(filament::Engine* engine, const mjModel* model, - int id, TextureType texture_type) { - if (id < 0 || id >= model->ntex) { - mju_error("Invalid texture index %d", id); - } - - const int width = model->tex_width[id]; - const int height = model->tex_height[id]; - const bool is_srgb = model->tex_colorspace[id] == mjCOLORSPACE_SRGB; - const int num_channels = model->tex_nchannel[id]; - const mjtByte* data = model->tex_data + model->tex_adr[id]; - filament::Texture* texture = - texture_type == TextureType::kNormal2d - ? Create2dTexture(engine, width, height, num_channels, data, is_srgb) - : CreateCubeTexture(engine, width, height, num_channels, data, - is_srgb); - return texture; -} } // namespace mujoco diff --git a/src/experimental/filament/filament/model_util.h b/src/experimental/filament/filament/model_util.h index 46f9beec10..edf14fedd0 100644 --- a/src/experimental/filament/filament/model_util.h +++ b/src/experimental/filament/filament/model_util.h @@ -17,6 +17,7 @@ #include +#include #include #include #include @@ -36,26 +37,17 @@ enum class MeshType { kHeightField, }; -// The types of textures stored in the mjModel. -enum class TextureType { - kNormal2d, - kCube, -}; - // Generates a filament VertexBuffer for a given mesh in the mjModel. filament::VertexBuffer* CreateVertexBuffer(filament::Engine* engine, const mjModel* model, int id, - MeshType mesh_type); + MeshType mesh_type, + filament::Box* bounds); // Generates a filament IndexBuffer for a given mesh in the mjModel. filament::IndexBuffer* CreateIndexBuffer(filament::Engine* engine, const mjModel* model, int id, MeshType mesh_type); -// Generates a filament Texture for a given 2D texture in the mjModel. -filament::Texture* CreateTexture(filament::Engine* engine, const mjModel* model, - int id, TextureType texture_type); - // Reads a value with the given name from the mjModel's data sections. The // default_value is returned if the named element is not found. template diff --git a/src/experimental/filament/filament/object_manager.cc b/src/experimental/filament/filament/object_manager.cc index 4aedb12034..80679284d3 100644 --- a/src/experimental/filament/filament/object_manager.cc +++ b/src/experimental/filament/filament/object_manager.cc @@ -14,12 +14,11 @@ #include "experimental/filament/filament/object_manager.h" -#include +#include #include +#include #include #include -#include -#include #include #include @@ -29,11 +28,7 @@ #include #include #include -#include "experimental/filament/filament/buffer_util.h" -#include "experimental/filament/filament/builtins.h" -#include "experimental/filament/filament/model_util.h" -#include "experimental/filament/filament/texture_util.h" -#include "experimental/filament/render_context_filament.h" +#include "experimental/filament/filament/texture.h" #include "user/user_resource.h" namespace mujoco { @@ -64,23 +59,8 @@ struct Asset { } // namespace -ObjectManager::ObjectManager(const mjModel* model, filament::Engine* engine) - : model_(model), engine_(engine) { - const int nquad = model->vis.quality.numquads; - const int nstack = model->vis.quality.numstacks; - const int nslice = model->vis.quality.numslices; - - shapes_[kLine] = CreateLine(engine_); - shapes_[kBox] = CreateBox(engine_, nquad); - shapes_[kLineBox] = CreateLineBox(engine_); - shapes_[kCone] = CreateCone(engine_, nstack, nslice); - shapes_[kDisk] = CreateDisk(engine_, nslice); - shapes_[kDome] = CreateDome(engine_, nstack / 2, nslice); - shapes_[kTube] = CreateTube(engine_, nstack, nslice); - shapes_[kPlane] = CreatePlane(engine_, nquad); - shapes_[kSphere] = CreateSphere(engine_, nstack, nslice); - shapes_[kTriangle] = CreateTriangle(engine_); - +ObjectManager::ObjectManager(filament::Engine* engine) + : engine_(engine) { auto LoadMaterial = [this](std::string_view filename) { Asset asset(filename); filament::Material::Builder material_builder; @@ -107,151 +87,57 @@ ObjectManager::ObjectManager(const mjModel* model, filament::Engine* engine) materials_[kUnlitDepth] = LoadMaterial("unlit_depth.filamat"); materials_[kUnlitUi] = LoadMaterial("unlit_ui.filamat"); - for (int i = 0; i < model_->ntex; ++i) { - UploadTexture(model_, i); - } - for (int i = 0; i < model_->nmesh; ++i) { - UploadMesh(model_, i); - } - for (int i = 0; i < model_->nhfield; ++i) { - UploadHeightField(model_, i); - } - static uint8_t black_rgb[3] = {0, 0, 0}; - fallback_black_ = Create2dTexture(engine_, 1, 1, 3, black_rgb, false); static uint8_t white_rgb[3] = {255, 255, 255}; - fallback_white_ = Create2dTexture(engine_, 1, 1, 3, white_rgb, false); static uint8_t normal_data[3] = {128, 128, 255}; - fallback_normal_ = Create2dTexture(engine_, 1, 1, 3, normal_data, false); static uint8_t orm_data[3] = {0, 255, 0}; - fallback_orm_ = Create2dTexture(engine_, 1, 1, 3, orm_data, false); - fallback_textures_[mjTEXROLE_USER] = fallback_black_; - fallback_textures_[mjTEXROLE_RGB] = fallback_white_; - fallback_textures_[mjTEXROLE_OCCLUSION] = fallback_white_; - fallback_textures_[mjTEXROLE_ROUGHNESS] = fallback_white_; - fallback_textures_[mjTEXROLE_METALLIC] = fallback_black_; - fallback_textures_[mjTEXROLE_NORMAL] = fallback_normal_; - fallback_textures_[mjTEXROLE_EMISSIVE] = fallback_black_; - fallback_textures_[mjTEXROLE_ORM] = fallback_orm_; + TextureConfig config; + DefaultTextureConfig(&config); + config.width = 1; + config.height = 1; + config.target = mjTEXTURE_2D; + config.format = mjPIXEL_FORMAT_RGB8; + config.color_space = mjCOLORSPACE_LINEAR; + + auto CreateFallbackTexture = [this, &config](uint8_t color[3]) { + auto texture = std::make_unique(engine_, config); + + TextureData payload; + DefaultTextureData(&payload); + payload.bytes = color; + payload.nbytes = 3; + payload.release_callback = nullptr; + payload.user_data = nullptr; + texture->Upload(payload); + return texture; + }; + + fallback_black_ = CreateFallbackTexture(black_rgb); + fallback_white_ = CreateFallbackTexture(white_rgb); + fallback_normal_ = CreateFallbackTexture(normal_data); + fallback_orm_ = CreateFallbackTexture(orm_data); - fallback_indirect_light_ = LoadFallbackIndirectLight("ibl.ktx", 1.0f); + fallback_textures_[mjTEXROLE_USER] = fallback_black_.get(); + fallback_textures_[mjTEXROLE_RGB] = fallback_white_.get(); + fallback_textures_[mjTEXROLE_OCCLUSION] = fallback_white_.get(); + fallback_textures_[mjTEXROLE_ROUGHNESS] = fallback_white_.get(); + fallback_textures_[mjTEXROLE_METALLIC] = fallback_black_.get(); + fallback_textures_[mjTEXROLE_NORMAL] = fallback_normal_.get(); + fallback_textures_[mjTEXROLE_EMISSIVE] = fallback_black_.get(); + fallback_textures_[mjTEXROLE_ORM] = fallback_orm_.get(); - specular_multiplier_ = ReadElement( - model_, "filament.phong.specular_multiplier", specular_multiplier_); - shininess_multiplier_ = ReadElement( - model_, "filament.phong.shininess_multiplier", shininess_multiplier_); - emissive_multiplier_ = ReadElement( - model_, "filament.phong.emissive_multiplier", emissive_multiplier_); + LoadFallbackIndirectLight("ibl.ktx", 1.0f); } ObjectManager::~ObjectManager() { - for (auto& iter : skyboxes_) { - engine_->destroy(iter); - } - for (auto& iter : indirect_lights_) { - engine_->destroy(iter); + if (fallback_indirect_light_) { + engine_->destroy(fallback_indirect_light_); } + fallback_indirect_light_texture_.reset(); for (auto& iter : materials_) { engine_->destroy(iter); } - for (auto& iter : meshes_) { - engine_->destroy(iter.second.vertex_buffer); - engine_->destroy(iter.second.index_buffer); - } - for (auto& iter : shapes_) { - engine_->destroy(iter.vertex_buffer); - engine_->destroy(iter.index_buffer); - } - for (auto& iter : textures_) { - engine_->destroy(iter.second); - } - // fallback_textures_ maps to these textures. - engine_->destroy(fallback_white_); - engine_->destroy(fallback_black_); - engine_->destroy(fallback_normal_); - engine_->destroy(fallback_orm_); -} - -void ObjectManager::UploadMesh(const mjModel* model, int id) { - if (model != model_) { - mju_error("Model mismatch."); - } - if (id < 0 || id >= model->nmesh) { - mju_error("Invalid mesh index %d", id); - } - - if (auto iter = meshes_.find(id); iter != meshes_.end()) { - engine_->destroy(iter->second.vertex_buffer); - engine_->destroy(iter->second.index_buffer); - } - if (auto iter = convex_hulls_.find(id); iter != convex_hulls_.end()) { - engine_->destroy(iter->second.vertex_buffer); - engine_->destroy(iter->second.index_buffer); - } - - FilamentBuffers& buffers = meshes_[id]; - buffers.vertex_buffer = - CreateVertexBuffer(engine_, model, id, MeshType::kNormal); - buffers.index_buffer = - CreateIndexBuffer(engine_, model, id, MeshType::kNormal); - - if (model->mesh_graphadr[id] >= 0) { - FilamentBuffers& hull_buffers = convex_hulls_[id]; - hull_buffers.vertex_buffer = - CreateVertexBuffer(engine_, model, id, MeshType::kConvexHull); - hull_buffers.index_buffer = - CreateIndexBuffer(engine_, model, id, MeshType::kConvexHull); - } -} - -void ObjectManager::UploadTexture(const mjModel* model, int id) { - if (model != model_) { - mju_error("Model mismatch."); - } - if (id < 0 || id >= model->ntex) { - mju_error("Invalid texture index: %d", id); - } - - if (auto iter = textures_.find(id); iter != textures_.end()) { - engine_->destroy(iter->second); - } - - const int texture_type = model->tex_type[id]; - if (model->tex_height[id] == 1) { - const mjtByte* bytes = model->tex_data + model->tex_adr[id]; - const int num_bytes = model->tex_width[id]; - textures_[id] = - CreateKtxTexture(engine_, bytes, num_bytes, spherical_harmonics_[id]); - } else if (texture_type == mjTEXTURE_2D) { - textures_[id] = CreateTexture(engine_, model, id, TextureType::kNormal2d); - } else if (texture_type == mjTEXTURE_CUBE) { - textures_[id] = CreateTexture(engine_, model, id, TextureType::kCube); - } else if (texture_type == mjTEXTURE_SKYBOX) { - textures_[id] = CreateTexture(engine_, model, id, TextureType::kCube); - } else { - mju_error("Unsupported: Texture type: %d", texture_type); - } -} - -void ObjectManager::UploadHeightField(const mjModel* model, int id) { - if (model != model_) { - mju_error("Model mismatch."); - } - if (id < 0 || id >= model->nhfield) { - mju_error("Invalid height field index %d", id); - } - - if (auto iter = height_fields_.find(id); iter != height_fields_.end()) { - engine_->destroy(iter->second.vertex_buffer); - engine_->destroy(iter->second.index_buffer); - } - - FilamentBuffers& buffers = height_fields_[id]; - buffers.vertex_buffer = - CreateVertexBuffer(engine_, model, id, MeshType::kHeightField); - buffers.index_buffer = - CreateIndexBuffer(engine_, model, id, MeshType::kHeightField); } filament::Material* ObjectManager::GetMaterial(MaterialType type) const { @@ -261,137 +147,70 @@ filament::Material* ObjectManager::GetMaterial(MaterialType type) const { return materials_[type]; } -const FilamentBuffers* ObjectManager::GetMeshBuffer(int data_id) const { - // As defined by mjv_updateScene: - // original mesh: mesh_id * 2 - // convex hull: (mesh_id * 2) + 1 - const int mesh_id = data_id / 2; - if (data_id % 2 == 0) { - auto it = meshes_.find(mesh_id); - return it != meshes_.end() ? &it->second : nullptr; - } else { - auto it = convex_hulls_.find(mesh_id); - return it != convex_hulls_.end() ? &it->second : nullptr; - } -} - -const FilamentBuffers* ObjectManager::GetHeightFieldBuffer( - int hfield_id) const { - auto it = height_fields_.find(hfield_id); - return it != height_fields_.end() ? &it->second : nullptr; -} - -const FilamentBuffers* ObjectManager::GetShapeBuffer(ShapeType shape) const { - if (shape < 0 || shape >= kNumShapes) { - mju_error("Invalid shape type: %d", shape); - } - return &shapes_[shape]; -} - -const filament::Texture* ObjectManager::GetTexture(int tex_id) const { - auto it = textures_.find(tex_id); - return it != textures_.end() ? it->second : nullptr; -} - -const filament::Texture* ObjectManager::GetTexture(int mat_id, int role) const { - if (mat_id < 0 || mat_id >= model_->nmat || role < 0 || role >= mjNTEXROLE) { - return nullptr; - } - const int tex_id = model_->mat_texid[mat_id * mjNTEXROLE + role]; - return GetTexture(tex_id); -} - -const filament::Texture* ObjectManager::GetTextureWithFallback(int mat_id, - int role) const { - if (auto texture = GetTexture(mat_id, role)) { - return texture; - } - return GetFallbackTexture(role); -} - -const filament::Texture* ObjectManager::GetFallbackTexture(int role) const { - auto iter = fallback_textures_.find(role); - if (iter != fallback_textures_.end()) { - return iter->second; +const Texture* ObjectManager::GetFallbackTexture( + mjtTextureRole role) const { + if (role < 0 || role >= mjNTEXROLE) { + mju_error("Invalid texture role: %d", role); } - return nullptr; + return fallback_textures_[role]; } filament::IndirectLight* ObjectManager::GetFallbackIndirectLight() { return fallback_indirect_light_; } -filament::IndirectLight* ObjectManager::CreateIndirectLight(int tex_id, - float intensity) { - filament::Texture* texture = nullptr; - auto texture_iter = textures_.find(tex_id); - if (texture_iter != textures_.end()) { - texture = texture_iter->second; +void ObjectManager::LoadFallbackIndirectLight( + std::string_view filename, float intensity) { + fallback_indirect_light_texture_.reset(); + if (fallback_indirect_light_ != nullptr) { + engine_->destroy(fallback_indirect_light_); + fallback_indirect_light_ = nullptr; } - if (texture == nullptr) { - return nullptr; + Asset* asset = new Asset(filename); + auto release_asset = +[](void* user_data) { + delete static_cast(user_data); + }; + if (asset->size == 0) { + release_asset(asset); + return; } - SphericalHarmonics* spherical_harmonics = nullptr; - auto sh_iter = spherical_harmonics_.find(tex_id); - if (sh_iter != spherical_harmonics_.end()) { - spherical_harmonics = &sh_iter->second; - } + TextureConfig config; + DefaultTextureConfig(&config); + config.width = 1; + config.height = 1; + config.target = mjTEXTURE_CUBE; + config.format = mjPIXEL_FORMAT_KTX; + config.color_space = mjCOLORSPACE_AUTO; - return CreateIndirectLight(texture, spherical_harmonics, intensity); -} + fallback_indirect_light_texture_ = std::make_unique(engine_, config); -filament::IndirectLight* ObjectManager::LoadFallbackIndirectLight( - std::string_view filename, float intensity) { - Asset asset(filename); - if (asset.size == 0) { - return nullptr; + TextureData payload; + DefaultTextureData(&payload); + payload.bytes = asset->payload; + payload.nbytes = static_cast(asset->size); + payload.release_callback = release_asset; + payload.user_data = asset; + + fallback_indirect_light_texture_->Upload(payload); + if (fallback_indirect_light_texture_ == nullptr) { + return; } - filament::math::float3 spherical_harmonics[9]; - filament::Texture* tex = - CreateKtxTexture(engine_, reinterpret_cast(asset.payload), - asset.size, spherical_harmonics); - return CreateIndirectLight(tex, &spherical_harmonics, intensity); -} + const Texture::SphericalHarmonics* spherical_harmonics = + fallback_indirect_light_texture_->GetSphericalHarmonics(); -filament::IndirectLight* ObjectManager::CreateIndirectLight( - filament::Texture* texture, SphericalHarmonics* spherical_harmonics, - float intensity) { + // Build the indirect light. filament::IndirectLight::Builder builder; - builder.reflections(texture); - if (spherical_harmonics != nullptr) { + builder.reflections(fallback_indirect_light_texture_->GetFilamentTexture()); + if (spherical_harmonics) { builder.irradiance(3, *spherical_harmonics); } builder.intensity(intensity); // Rotate the light to match mujoco's Z-up convention. builder.rotation(filament::math::mat3f::rotation( filament::math::f::PI / 2, filament::math::float3{1, 0, 0})); - filament::IndirectLight* indirect_light = builder.build(*engine_); - indirect_lights_.push_back(indirect_light); - return indirect_light; + fallback_indirect_light_ = builder.build(*engine_); } - -filament::Skybox* ObjectManager::CreateSkybox() { - filament::Texture* skybox_texture = nullptr; - for (auto& iter : textures_) { - const int texture_type = model_->tex_type[iter.first]; - if (texture_type == mjTEXTURE_SKYBOX) { - skybox_texture = iter.second; - break; - } - } - - if (skybox_texture == nullptr) { - return nullptr; - } - - filament::Skybox::Builder builder; - builder.environment(skybox_texture); - filament::Skybox* skybox = builder.build(*engine_); - skyboxes_.push_back(skybox); - return skybox; -} - } // namespace mujoco diff --git a/src/experimental/filament/filament/object_manager.h b/src/experimental/filament/filament/object_manager.h index 6fe51dccbe..523203479d 100644 --- a/src/experimental/filament/filament/object_manager.h +++ b/src/experimental/filament/filament/object_manager.h @@ -16,25 +16,21 @@ #define MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_OBJECT_MANAGER_H_ #include -#include +#include #include -#include -#include #include #include #include -#include -#include -#include "experimental/filament/filament/buffer_util.h" -#include "experimental/filament/render_context_filament.h" +#include +#include "experimental/filament/filament/texture.h" namespace mujoco { // Creates and owns various filament objects based on the data in a mjrContext. class ObjectManager { public: - ObjectManager(const mjModel* model, filament::Engine* engine); + ObjectManager(filament::Engine* engine); ~ObjectManager(); enum MaterialType { @@ -59,85 +55,34 @@ class ObjectManager { kNumMaterials, }; - enum ShapeType { - kLine, - kLineBox, - kPlane, - kTriangle, - kBox, - kSphere, - kCone, - kDisk, - kDome, - kTube, - kNumShapes, - }; - - using SphericalHarmonics = filament::math::float3[9]; - - void UploadMesh(const mjModel* model, int id); - - void UploadTexture(const mjModel* model, int id); - - void UploadHeightField(const mjModel* model, int id); - - // Returns the filament engine used by the ObjectManager to create filament - // objects. + // Returns the filament Engine that owns the assets. filament::Engine* GetEngine() const { return engine_; } + // Returns the Material of the given type. filament::Material* GetMaterial(MaterialType type) const; - // Returns the cached instance of a filament object created from the mjModel. - const FilamentBuffers* GetMeshBuffer(int data_id) const; - const FilamentBuffers* GetShapeBuffer(ShapeType shape) const; - const FilamentBuffers* GetHeightFieldBuffer(int hfield_id) const; - const filament::Texture* GetTexture(int tex_id) const; - const filament::Texture* GetTexture(int mat_id, int role) const; - const filament::Texture* GetTextureWithFallback(int mat_id, int role) const; - const filament::Texture* GetFallbackTexture(int role) const; - filament::IndirectLight* GetFallbackIndirectLight(); - - // Creates and returns a new instance of a filament object. The objects are - // owned by the ObjectManager and will be deleted in the destructor. - filament::Skybox* CreateSkybox(); - filament::IndirectLight* CreateIndirectLight( - filament::Texture* texture, SphericalHarmonics* spherical_harmonics, - float intensity); - filament::IndirectLight* CreateIndirectLight(int tex_id, float intensity); - filament::IndirectLight* LoadFallbackIndirectLight(std::string_view filename, - float intensity); + // Returns the fallback Texture with the given role. + const Texture* GetFallbackTexture(mjtTextureRole role) const; - float GetSpecularMultiplier() const { return specular_multiplier_; } - float GetShininessMultiplier() const { return shininess_multiplier_; } - float GetEmissiveMultiplier() const { return emissive_multiplier_; } + // Returns the fallback IndirectLight. + filament::IndirectLight* GetFallbackIndirectLight(); - const mjModel* GetModel() const { return model_; } + // Loads an indirect light from a file, setting it to the fallback. + void LoadFallbackIndirectLight(std::string_view filename, float intensity); ObjectManager(const ObjectManager&) = delete; ObjectManager& operator=(const ObjectManager&) = delete; private: - const mjModel* model_ = nullptr; filament::Engine* engine_ = nullptr; - - std::array shapes_; std::array materials_; - std::vector skyboxes_; - std::vector indirect_lights_; - std::unordered_map meshes_; - std::unordered_map convex_hulls_; - std::unordered_map height_fields_; - std::unordered_map textures_; - std::unordered_map spherical_harmonics_; - std::unordered_map fallback_textures_; - filament::Texture* fallback_white_ = nullptr; - filament::Texture* fallback_black_ = nullptr; - filament::Texture* fallback_normal_ = nullptr; - filament::Texture* fallback_orm_ = nullptr; + std::array fallback_textures_; + std::unique_ptr fallback_white_ = nullptr; + std::unique_ptr fallback_black_ = nullptr; + std::unique_ptr fallback_normal_ = nullptr; + std::unique_ptr fallback_orm_ = nullptr; + std::unique_ptr fallback_indirect_light_texture_ = nullptr; filament::IndirectLight* fallback_indirect_light_ = nullptr; - float specular_multiplier_ = 0.2f; - float shininess_multiplier_ = 0.1f; - float emissive_multiplier_ = 0.3f; }; } // namespace mujoco diff --git a/src/experimental/filament/filament/render_target_util.cc b/src/experimental/filament/filament/render_target_util.cc index 52cbc25cbc..2e4ce391f2 100644 --- a/src/experimental/filament/filament/render_target_util.cc +++ b/src/experimental/filament/filament/render_target_util.cc @@ -14,47 +14,15 @@ #include "experimental/filament/filament/render_target_util.h" +#include + #include #include #include -#include +#include "experimental/filament/filament/texture.h" namespace mujoco { -static filament::Texture* CreateRenderTargetTexture( - filament::Engine* engine, int width, int height, - RenderTargetTextureType type) { - filament::Texture::Builder builder; - builder.width(width); - builder.height(height); - switch (type) { - case kRenderTargetColor: - builder.usage(filament::Texture::Usage::COLOR_ATTACHMENT | - filament::Texture::Usage::BLIT_SRC); - builder.format(filament::Texture::InternalFormat::RGB8); - break; - case kRenderTargetDepth: - builder.usage(filament::Texture::Usage::DEPTH_ATTACHMENT | - filament::Texture::Usage::SAMPLEABLE); - builder.format(filament::Texture::InternalFormat::DEPTH32F); - break; - case kRenderTargetDepthColor: - builder.usage(filament::Texture::Usage::COLOR_ATTACHMENT | - filament::Texture::Usage::BLIT_SRC); - builder.format(filament::Texture::InternalFormat::R32F); - break; - case kRenderTargetReflectionColor: - builder.usage(filament::Texture::Usage::COLOR_ATTACHMENT | - filament::Texture::Usage::BLIT_SRC | - filament::Texture::Usage::SAMPLEABLE); - builder.format(filament::Texture::InternalFormat::RGBA8); - break; - default: - mju_error("Unknown type: %d", static_cast(type)); - } - return builder.build(*engine); -} - RenderTargetAndTextures::RenderTargetAndTextures(filament::Engine* engine, RenderTargetTextureType color, RenderTargetTextureType depth) @@ -73,15 +41,15 @@ void RenderTargetAndTextures::Prepare(int width, int height) { height_ = height; color_texture_ = - CreateRenderTargetTexture(engine_, width, height, color_type_); + std::make_unique(engine_, color_type_, width, height); depth_texture_ = - CreateRenderTargetTexture(engine_, width, height, depth_type_); + std::make_unique(engine_, depth_type_, width, height); filament::RenderTarget::Builder builder; builder.texture(filament::RenderTarget::AttachmentPoint::COLOR, - color_texture_); + color_texture_->GetFilamentTexture()); builder.texture(filament::RenderTarget::AttachmentPoint::DEPTH, - depth_texture_); + depth_texture_->GetFilamentTexture()); render_target_ = builder.build(*engine_); } @@ -90,14 +58,20 @@ void RenderTargetAndTextures::Destroy() { engine_->destroy(render_target_); render_target_ = nullptr; } - if (color_texture_) { - engine_->destroy(color_texture_); - color_texture_ = nullptr; - } - if (depth_texture_) { - engine_->destroy(depth_texture_); - depth_texture_ = nullptr; - } + color_texture_.reset(); + depth_texture_.reset(); +} + +Texture* RenderTargetAndTextures::GetColorTexture() const { + return color_texture_.get(); +} + +Texture* RenderTargetAndTextures::GetDepthTexture() const { + return depth_texture_.get(); +} + +filament::RenderTarget* RenderTargetAndTextures::GetRenderTarget() const { + return render_target_; } } // namespace mujoco diff --git a/src/experimental/filament/filament/render_target_util.h b/src/experimental/filament/filament/render_target_util.h index e8dfb9392e..ce6b96dcc9 100644 --- a/src/experimental/filament/filament/render_target_util.h +++ b/src/experimental/filament/filament/render_target_util.h @@ -15,21 +15,14 @@ #ifndef MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_RENDER_TARGET_UTIL_H_ #define MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_RENDER_TARGET_UTIL_H_ +#include + #include #include +#include "experimental/filament/filament/texture.h" namespace mujoco { -// The different types of textures we can create for a render target. -enum RenderTargetTextureType { - kRenderTargetNone, - kRenderTargetColor, - kRenderTargetDepth, - kRenderTargetDepthColor, - kRenderTargetReflectionColor, - kNumRenderTargetTextureTypes, -}; - // Manages a filament RenderTarget and the textures which are bound to it. class RenderTargetAndTextures { public: @@ -48,23 +41,23 @@ class RenderTargetAndTextures { void Prepare(int width, int height); // Returns the color texture. - filament::Texture* GetColorTexture() const { return color_texture_; } + Texture* GetColorTexture() const; // Returns the depth texture. - filament::Texture* GetDepthTexture() const { return depth_texture_; } + Texture* GetDepthTexture() const; // Returns the render target. - filament::RenderTarget* GetRenderTarget() const { return render_target_; } + filament::RenderTarget* GetRenderTarget() const; private: void Destroy(); filament::Engine* engine_ = nullptr; - filament::Texture* color_texture_ = nullptr; - filament::Texture* depth_texture_ = nullptr; filament::RenderTarget* render_target_ = nullptr; - RenderTargetTextureType color_type_ = kRenderTargetNone; - RenderTargetTextureType depth_type_ = kRenderTargetNone; + std::unique_ptr color_texture_ = nullptr; + std::unique_ptr depth_texture_ = nullptr; + RenderTargetTextureType color_type_; + RenderTargetTextureType depth_type_; int width_ = 0; int height_ = 0; }; diff --git a/src/experimental/filament/filament/renderables.cc b/src/experimental/filament/filament/renderables.cc index 0eeb0c87dc..fcce336ff7 100644 --- a/src/experimental/filament/filament/renderables.cc +++ b/src/experimental/filament/filament/renderables.cc @@ -105,13 +105,16 @@ utils::Entity Renderables::CreateEntity(const FilamentBuffers& buffers) { if (material_instance_) { builder.material(0, material_instance_); } - builder.boundingBox(buffers.bounds) - .culling(false) - .castShadows(cast_shadows_) - .receiveShadows(receive_shadows_) - .layerMask(0xff, layer_mask_) - .priority(priority_) - .screenSpaceContactShadows(true); + if (buffers.bounds.has_value()) { + builder.boundingBox(buffers.bounds.value()); + } else { + builder.culling(false); + } + builder.castShadows(cast_shadows_); + builder.receiveShadows(receive_shadows_); + builder.layerMask(0xff, layer_mask_); + builder.priority(priority_); + builder.screenSpaceContactShadows(true);; builder.build(*engine_, entity); if (assigned_scene_) { diff --git a/src/experimental/filament/filament/scene_view.cc b/src/experimental/filament/filament/scene_view.cc index cba553955c..7009dd176c 100644 --- a/src/experimental/filament/filament/scene_view.cc +++ b/src/experimental/filament/filament/scene_view.cc @@ -46,9 +46,11 @@ #include "experimental/filament/filament/gui_view.h" #include "experimental/filament/filament/light.h" #include "experimental/filament/filament/math_util.h" +#include "experimental/filament/filament/model_objects.h" #include "experimental/filament/filament/model_util.h" #include "experimental/filament/filament/object_manager.h" #include "experimental/filament/filament/render_target_util.h" +#include "experimental/filament/filament/texture.h" namespace mujoco { @@ -113,37 +115,38 @@ static void SetupReflectionCamera(const mat4& surface_xform, reflection_camera->setCustomProjection(oblique, near, far); } -SceneView::SceneView(filament::Engine* engine, ObjectManager* object_mgr) - : object_mgr_(object_mgr), engine_(engine) { - scene_ = engine_->createScene(); - camera_ = engine_->createCamera(utils::EntityManager::get().create()); - reflect_camera_ = engine_->createCamera(utils::EntityManager::get().create()); +SceneView::SceneView(ObjectManager* object_mgr, const mjModel* model) + : object_mgr_(object_mgr) { + filament::Engine* engine = object_mgr_->GetEngine(); + model_objects_ = std::make_unique(model, engine); + + scene_ = engine->createScene(); + camera_ = engine->createCamera(utils::EntityManager::get().create()); + reflect_camera_ = engine->createCamera(utils::EntityManager::get().create()); for (auto& view : views_) { - view = engine_->createView(); + view = engine->createView(); view->setScene(scene_); view->setCamera(camera_); } - reflect_view_ = engine_->createView(); + reflect_view_ = engine->createView(); reflect_view_->setScene(scene_); reflect_view_->setCamera(reflect_camera_); reflect_view_->setShadowingEnabled(false); reflect_view_->setPostProcessingEnabled(false); - const mjModel* m = object_mgr_->GetModel(); - // Configure options for the normal view. auto& cg = color_grading_options_; - cg.exposure = ReadElement(m, "filament.out.exposure", cg.exposure); - cg.contrast = ReadElement(m, "filament.out.contrast", cg.contrast); - cg.vibrance = ReadElement(m, "filament.out.vibrance", cg.vibrance); - cg.saturation = ReadElement(m, "filament.out.saturation", cg.saturation); - cg.temperature = ReadElement(m, "filament.out.temperature", cg.temperature); - cg.tint = ReadElement(m, "filament.out.tint", cg.tint); + cg.exposure = ReadElement(model, "filament.out.exposure", cg.exposure); + cg.contrast = ReadElement(model, "filament.out.contrast", cg.contrast); + cg.vibrance = ReadElement(model, "filament.out.vibrance", cg.vibrance); + cg.saturation = ReadElement(model, "filament.out.saturation", cg.saturation); + cg.temperature = ReadElement(model, "filament.out.temperature", cg.temperature); + cg.tint = ReadElement(model, "filament.out.tint", cg.tint); auto tone_mapping = - ReadElement(m, "filament.out.tone_mapping"); + ReadElement(model, "filament.out.tone_mapping"); if (tone_mapping == "aces") { cg.tone_mapper = ToneMapperType::kACES; } else if (tone_mapping == "aces_legacy") { @@ -158,9 +161,9 @@ SceneView::SceneView(filament::Engine* engine, ObjectManager* object_mgr) SetColorGradingOptions(cg); auto ao = views_[kNormalIndex]->getAmbientOcclusionOptions(); - ao.enabled = ReadElement(m, "filament.ao.enabled", true); - ao.bentNormals = ReadElement(m, "filament.ao.bent_normals", false); - ao.ssct.enabled = ReadElement(m, "filament.ao.ssct", ao.ssct.enabled); + ao.enabled = ReadElement(model, "filament.ao.enabled", true); + ao.bentNormals = ReadElement(model, "filament.ao.bent_normals", false); + ao.ssct.enabled = ReadElement(model, "filament.ao.ssct", ao.ssct.enabled); ao.quality = filament::QualityLevel::ULTRA; ao.lowPassFilter = filament::QualityLevel::ULTRA; ao.upsampling = filament::QualityLevel::ULTRA; @@ -168,16 +171,16 @@ SceneView::SceneView(filament::Engine* engine, ObjectManager* object_mgr) views_[kNormalIndex]->setAmbientOcclusionOptions(ao); auto msaa = views_[kNormalIndex]->getMultiSampleAntiAliasingOptions(); - msaa.enabled = ReadElement(m, "filament.msaa.enabled", true); + msaa.enabled = ReadElement(model, "filament.msaa.enabled", true); views_[kNormalIndex]->setMultiSampleAntiAliasingOptions(msaa); default_shadow_map_size_ = ReadElement( - m, "filament.shadows.map_size", default_shadow_map_size_); + model, "filament.shadows.map_size", default_shadow_map_size_); default_vsm_blur_width_ = ReadElement( - m, "filament.shadows.vsm_blur_width", default_vsm_blur_width_); + model, "filament.shadows.vsm_blur_width", default_vsm_blur_width_); auto shadow_type = views_[kNormalIndex]->getShadowType(); - shadow_type = ReadElement(m, "filament.shadows.type", shadow_type); + shadow_type = ReadElement(model, "filament.shadows.type", shadow_type); views_[kNormalIndex]->setShadowType(shadow_type); // Disable post processing for the depth and segmentation views to preserve @@ -190,44 +193,44 @@ SceneView::SceneView(filament::Engine* engine, ObjectManager* object_mgr) auto& tm = engine->getTransformManager(); tm.create(fog); auto rotation_axis = ReadElement( - m, "filament.fog.rotation_axis", float3{-1, 0, 0}); + model, "filament.fog.rotation_axis", float3{-1, 0, 0}); tm.setTransform(tm.getInstance(fog), mat4::rotation(filament::math::f::PI / 2, rotation_axis)); auto fog_opts = views_[kNormalIndex]->getFogOptions(); - fog_opts.enabled = ReadElement(m, "filament.fog.enabled", fog_opts.enabled); - fog_opts.color = ReadElement(m, "filament.fog.color", fog_opts.color); + fog_opts.enabled = + ReadElement(model, "filament.fog.enabled", fog_opts.enabled); + fog_opts.color = ReadElement(model, "filament.fog.color", fog_opts.color); fog_opts.distance = ReadElement( - m, "filament.fog.distance", fog_opts.distance); + model, "filament.fog.distance", fog_opts.distance); fog_opts.density = ReadElement( - m, "filament.fog.density", fog_opts.density); + model, "filament.fog.density", fog_opts.density); fog_opts.cutOffDistance = ReadElement( - m, "filament.fog.cutOffDistance", fog_opts.cutOffDistance); + model, "filament.fog.cutOffDistance", fog_opts.cutOffDistance); fog_opts.maximumOpacity = ReadElement( - m, "filament.fog.maximumOpacity", fog_opts.maximumOpacity); - fog_opts.height = ReadElement(m, "filament.fog.height", fog_opts.height); + model, "filament.fog.maximumOpacity", fog_opts.maximumOpacity); + fog_opts.height = ReadElement(model, "filament.fog.height", fog_opts.height); fog_opts.heightFalloff = ReadElement( - m, "filament.fog.heightFalloff", fog_opts.heightFalloff); + model, "filament.fog.heightFalloff", fog_opts.heightFalloff); fog_opts.inScatteringStart = ReadElement( - m, "filament.fog.inScatteringStart", fog_opts.inScatteringStart); + model, "filament.fog.inScatteringStart", fog_opts.inScatteringStart); fog_opts.inScatteringSize = ReadElement( - m, "filament.fog.inScatteringSize", fog_opts.inScatteringSize); + model, "filament.fog.inScatteringSize", fog_opts.inScatteringSize); views_[kNormalIndex]->setFogOptions(fog_opts); fallback_head_light_intensity_ = - ReadElement(m, "filament.fallback.head_light_intensity", + ReadElement(model, "filament.fallback.head_light_intensity", fallback_head_light_intensity_); fallback_scene_light_intensity_ = - ReadElement(m, "filament.fallback.scene_light_intensity", + ReadElement(model, "filament.fallback.scene_light_intensity", fallback_scene_light_intensity_); fallback_environment_light_intensity_ = - ReadElement(m, "filament.fallback.environment_light_intensity", + ReadElement(model, "filament.fallback.environment_light_intensity", fallback_environment_light_intensity_); // Create an empty/black indirect light to ensure that the skybox is oriented // to respect mujoco's Z-up convention. - scene_->setIndirectLight( - object_mgr_->CreateIndirectLight(nullptr, nullptr, 100000)); + scene_->setIndirectLight(model_objects_->CreateIndirectLight(-1, 100000)); PrepareLights(); } @@ -237,15 +240,16 @@ SceneView::~SceneView() { drawables_.clear(); reflect_targets_.clear(); - engine_->destroyCameraComponent(reflect_camera_->getEntity()); - engine_->destroy(reflect_view_); + filament::Engine* engine = object_mgr_->GetEngine(); + engine->destroyCameraComponent(reflect_camera_->getEntity()); + engine->destroy(reflect_view_); - engine_->destroyCameraComponent(camera_->getEntity()); - engine_->destroy(views_[kNormalIndex]->getColorGrading()); + engine->destroyCameraComponent(camera_->getEntity()); + engine->destroy(views_[kNormalIndex]->getColorGrading()); for (auto& view : views_) { - engine_->destroy(view); + engine->destroy(view); } - engine_->destroy(scene_); + engine->destroy(scene_); } void SceneView::Render(filament::Renderer* renderer, DrawMode draw_mode, @@ -304,22 +308,23 @@ void SceneView::SetViewport(mjrRect viewport) { } void SceneView::SetColorGradingOptions(const ColorGradingOptions& opts) { + filament::Engine* engine = object_mgr_->GetEngine(); + auto tone_mapper = CreateToneMapper(opts.tone_mapper); auto color_grading = ToBuilder(color_grading_options_) .toneMapper(tone_mapper.get()) - .build(*engine_); + .build(*engine); views_[kNormalIndex]->setColorGrading(color_grading); - engine_->destroy(color_grading_); + engine->destroy(color_grading_); color_grading_ = color_grading; color_grading_options_ = opts; } void SceneView::SetEnvironmentLight(std::string_view filename, float intensity) { - auto* ibl = object_mgr_->LoadFallbackIndirectLight(filename, intensity); - if (ibl) { - scene_->setIndirectLight(ibl); - } + scene_->setIndirectLight(nullptr); + object_mgr_->LoadFallbackIndirectLight(filename, intensity); + scene_->setIndirectLight(object_mgr_->GetFallbackIndirectLight()); } void SceneView::SetFallbackEnvironmentLight(float intensity) { @@ -357,8 +362,9 @@ std::optional SceneView::ClipFromWorld(const float3& pos) const{ } void SceneView::PrepareLights() { - const mjModel* model = object_mgr_->GetModel(); - filament::Skybox* skybox = object_mgr_->CreateSkybox(); + filament::Engine* engine = object_mgr_->GetEngine(); + const mjModel* model = model_objects_->GetModel(); + filament::Skybox* skybox = model_objects_->CreateSkybox(); if (skybox) { scene_->setSkybox(skybox); } @@ -369,7 +375,7 @@ void SceneView::PrepareLights() { total_light_intensity += model->light_intensity[i]; if (model->light_type[i] == mjLIGHT_IMAGE) { - auto* indirect_light = object_mgr_->CreateIndirectLight( + auto* indirect_light = model_objects_->CreateIndirectLight( model->light_texid[i], model->light_intensity[i]); if (indirect_light) { scene_->setIndirectLight(indirect_light); @@ -390,7 +396,7 @@ void SceneView::PrepareLights() { params.spot_cone_angle = model->light_cutoff[i]; } - auto light_obj = std::make_unique(engine_, params); + auto light_obj = std::make_unique(engine, params); #ifndef __EMSCRIPTEN__ // TODO(b/458045799): Re-enable when lights work on glinux and chromebook. light_obj->AddToScene(scene_); @@ -408,7 +414,7 @@ void SceneView::PrepareLights() { params.type = mjLIGHT_DIRECTIONAL; params.castshadow = 0; params.intensity = 0; - auto light_obj = std::make_unique(engine_, params); + auto light_obj = std::make_unique(engine, params); #ifndef __EMSCRIPTEN__ // TODO(b/458045799): Re-enable when lights work on glinux and chromebook. light_obj->AddToScene(scene_); @@ -456,9 +462,10 @@ void SceneView::UpdateScene(const mjvScene* scene) { } } - auto drawable = std::make_unique(object_mgr_, *geom); + auto drawable = + std::make_unique(object_mgr_, model_objects_.get(), *geom); drawable->AddToScene(scene_); - drawable->Update(object_mgr_->GetModel(), scene, *geom); + drawable->Update(model_objects_->GetModel(), scene, *geom); if (drawable->IsReflective()) { AddReflectiveDrawable(drawable.get()); } @@ -507,9 +514,11 @@ void SceneView::AddReflectiveDrawable(Drawable* drawable) { // Ensure we have the same number of render targets as we do reflective // drawables. + filament::Engine* engine = object_mgr_->GetEngine(); while (reflect_targets_.size() < reflectives_.size()) { reflect_targets_.push_back(std::make_unique( - engine_, kRenderTargetReflectionColor, kRenderTargetDepth)); + engine, RenderTargetTextureType::kReflectionColor, + RenderTargetTextureType::kDepth)); } // Prepare a render target for the reflective drawable. @@ -519,7 +528,21 @@ void SceneView::AddReflectiveDrawable(Drawable* drawable) { drawable->UpdateReflectionTexture(target->GetColorTexture()); } -filament::Engine* SceneView::GetEngine() const { return engine_; } +void SceneView::UploadMesh(const mjModel* model, int id) { + model_objects_->UploadMesh(model, id); +} + +void SceneView::UploadTexture(const mjModel* model, int id) { + model_objects_->UploadTexture(model, id); +} + +void SceneView::UploadHeightField(const mjModel* model, int id) { + model_objects_->UploadHeightField(model, id); +} + +filament::Engine* SceneView::GetEngine() const { + return object_mgr_->GetEngine(); +} filament::View* SceneView::GetDefaultRenderView() { return views_[kNormalIndex]; diff --git a/src/experimental/filament/filament/scene_view.h b/src/experimental/filament/filament/scene_view.h index a46a54d29d..8ced8be13e 100644 --- a/src/experimental/filament/filament/scene_view.h +++ b/src/experimental/filament/filament/scene_view.h @@ -31,10 +31,12 @@ #include #include #include +#include #include "experimental/filament/filament/color_grading_options.h" #include "experimental/filament/filament/drawable.h" #include "experimental/filament/filament/light.h" #include "experimental/filament/filament/material.h" +#include "experimental/filament/filament/model_objects.h" #include "experimental/filament/filament/object_manager.h" #include "experimental/filament/filament/render_target_util.h" @@ -47,7 +49,7 @@ namespace mujoco { // different rendering modes (e.g. normal, depth, segmentation, etc.) class SceneView { public: - SceneView(filament::Engine* engine, ObjectManager* object_mgr); + SceneView(ObjectManager* object_mgr, const mjModel* model); ~SceneView(); // Updates all views to render into the given viewport. @@ -71,6 +73,10 @@ class SceneView { void Render(filament::Renderer* renderer, DrawMode draw_mode, filament::RenderTarget* target = nullptr); + void UploadMesh(const mjModel* model, int id); + void UploadTexture(const mjModel* model, int id); + void UploadHeightField(const mjModel* model, int id); + // Accessors. filament::Engine* GetEngine() const; filament::View* GetDefaultRenderView(); @@ -96,12 +102,12 @@ class SceneView { const filament::math::float3& pos) const; ObjectManager* object_mgr_ = nullptr; - filament::Engine* engine_ = nullptr; filament::Scene* scene_ = nullptr; filament::Camera* camera_ = nullptr; filament::ColorGrading* color_grading_ = nullptr; std::vector> lights_; std::vector> drawables_; + std::unique_ptr model_objects_; std::array views_; filament::math::mat4 clip_from_world_; ColorGradingOptions color_grading_options_; diff --git a/src/experimental/filament/filament/texture.cc b/src/experimental/filament/filament/texture.cc new file mode 100644 index 0000000000..1ce6277492 --- /dev/null +++ b/src/experimental/filament/filament/texture.cc @@ -0,0 +1,270 @@ +// Copyright 2025 DeepMind Technologies Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "experimental/filament/filament/texture.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace mujoco { + +static constexpr int kNumFacesPerCube = 6; + +static bool IsCompressed(const TextureConfig& config) { + return config.format == mjPIXEL_FORMAT_KTX; +} + +static bool IsCubeMap(const TextureConfig& config) { + return config.target == mjTEXTURE_CUBE || config.target == mjTEXTURE_SKYBOX; +} + +static int GetFaceHeight(const TextureConfig& config) { + int face_height = config.height; + if (config.width != config.height) { + if (config.width * kNumFacesPerCube != config.height) { + mju_error("Cube maps must contain 6 square images."); + } + face_height = config.height / kNumFacesPerCube; + } + if (config.width != face_height) { + mju_error("Cube map faces must be square."); + } + return face_height; +} + +static int GetNumChannels(const TextureConfig& config) { + switch (config.format) { + case mjPIXEL_FORMAT_R8: + return 1; + case mjPIXEL_FORMAT_RGB8: + return 3; + case mjPIXEL_FORMAT_RGBA8: + return 4; + default: + mju_error("Unsupported format: %d", (int)config.format); + return 0; + } +} + +static filament::Texture::Format GetTextureFormat(const TextureConfig& config) { + switch (config.format) { + case mjPIXEL_FORMAT_R8: + return filament::Texture::Format::R; + case mjPIXEL_FORMAT_RGB8: + return filament::Texture::Format::RGB; + case mjPIXEL_FORMAT_RGBA8: + return filament::Texture::Format::RGBA; + default: + mju_error("Unsupported format: %d", (int)config.format); + return filament::Texture::Format::UNUSED; + } +} + +static filament::Texture::InternalFormat GetTextureInternalFormat( + const TextureConfig& config) { + if (config.color_space == mjCOLORSPACE_SRGB) { + switch (config.format) { + case mjPIXEL_FORMAT_RGB8: + return filament::Texture::InternalFormat::SRGB8; + case mjPIXEL_FORMAT_RGBA8: + return filament::Texture::InternalFormat::SRGB8_A8; + default: + mju_error("Unsupported format: %d", (int)config.format); + return filament::Texture::InternalFormat::UNUSED; + } + } else { + switch (config.format) { + case mjPIXEL_FORMAT_R8: + return filament::Texture::InternalFormat::R8; + case mjPIXEL_FORMAT_RGB8: + return filament::Texture::InternalFormat::RGB8; + case mjPIXEL_FORMAT_RGBA8: + return filament::Texture::InternalFormat::RGBA8; + default: + mju_error("Unsupported format: %d", (int)config.format); + return filament::Texture::InternalFormat::UNUSED; + } + } +} + +void DefaultTextureData(TextureData* data) { + std::memset(data, 0, sizeof(TextureData)); +} + +void DefaultTextureConfig(TextureConfig* config) { + std::memset(config, 0, sizeof(TextureConfig)); +} + +Texture::Texture(filament::Engine* engine, const TextureConfig& config) + : engine_(engine), config_(config) { + if (IsCompressed(config_)) { + // We defer creation of compressed textures until Upload() is called. In + // the meantime, we don't really know anything about the texture (e.g. + // width, height, etc.). + return; + } + + filament::Texture::Builder builder; + builder.width(config_.width); + builder.height(config_.height); + builder.format(GetTextureInternalFormat(config_)); + + if (IsCubeMap(config_)) { + if (config_.format != mjPIXEL_FORMAT_RGB8) { + mju_error("Only support RGB cubemaps."); + return; + } + builder.height(GetFaceHeight(config_)); + builder.sampler(filament::Texture::Sampler::SAMPLER_CUBEMAP); + } else { + builder.sampler(filament::Texture::Sampler::SAMPLER_2D); + } + + if (config_.color_space != mjCOLORSPACE_SRGB) { + builder.usage(filament::Texture::Usage::GEN_MIPMAPPABLE | + filament::Texture::Usage::SAMPLEABLE | + filament::Texture::Usage::UPLOADABLE); + } + texture_ = builder.build(*engine_); +} + +Texture::Texture(filament::Engine* engine, RenderTargetTextureType type, + int width, int height) : engine_(engine) { + filament::Texture::Builder builder; + builder.width(width); + builder.height(height); + switch (type) { + case RenderTargetTextureType::kColor: + builder.usage(filament::Texture::Usage::COLOR_ATTACHMENT | + filament::Texture::Usage::BLIT_SRC); + builder.format(filament::Texture::InternalFormat::RGB8); + break; + case RenderTargetTextureType::kDepth: + builder.usage(filament::Texture::Usage::DEPTH_ATTACHMENT | + filament::Texture::Usage::SAMPLEABLE); + builder.format(filament::Texture::InternalFormat::DEPTH32F); + break; + case RenderTargetTextureType::kDepthColor: + builder.usage(filament::Texture::Usage::COLOR_ATTACHMENT | + filament::Texture::Usage::BLIT_SRC); + builder.format(filament::Texture::InternalFormat::R32F); + break; + case RenderTargetTextureType::kReflectionColor: + builder.usage(filament::Texture::Usage::COLOR_ATTACHMENT | + filament::Texture::Usage::BLIT_SRC | + filament::Texture::Usage::SAMPLEABLE); + builder.format(filament::Texture::InternalFormat::RGBA8); + break; + default: + mju_error("Unknown type: %d", static_cast(type)); + } + texture_ = builder.build(*engine); +} + +Texture::~Texture() { + ReleaseData(); + if (texture_) { + engine_->destroy(texture_); + } +} + +void Texture::Upload(const TextureData& data) { + user_data_ = data.user_data; + release_callback_ = data.release_callback; + + if (data.bytes == nullptr || data.nbytes == 0) { + ReleaseData(); + return; + } + + if (config_.format == mjPIXEL_FORMAT_KTX) { + image::Ktx1Bundle* bundle = new image::Ktx1Bundle( + reinterpret_cast(data.bytes), data.nbytes); + has_spherical_harmonics_ = true; + bundle->getSphericalHarmonics(spherical_harmonics_); + const bool is_srgb = false; + texture_ = ktxreader::Ktx1Reader::createTexture(engine_, bundle, is_srgb); + config_.width = texture_->getWidth(); + config_.height = texture_->getHeight(); + ReleaseData(); + return; + } + + const int num_channels = GetNumChannels(config_); + const filament::Texture::Type type = filament::Texture::Type::UBYTE; + const filament::Texture::Format format = GetTextureFormat(config_); + + if (!IsCubeMap(config_)) { + if (config_.width * config_.height * num_channels != data.nbytes) { + mju_error("Texture size does not match data size."); + } + + auto callback = +[](void* buffer, size_t size, void* user) { + reinterpret_cast(user)->ReleaseData(); + }; + filament::Texture::PixelBufferDescriptor desc(data.bytes, data.nbytes, + format, type, callback, this); + texture_->setImage(*engine_, 0, std::move(desc)); + } else { + const int face_size = config_.width * GetFaceHeight(config_) * num_channels; + const int num_bytes = face_size * kNumFacesPerCube; + filament::Texture::FaceOffsets offsets(face_size); + + if (config_.width == config_.height) { + uint8_t* copy = new uint8_t[num_bytes]; + auto release_callback = +[](void* buffer, size_t size, void* user) { + delete [] reinterpret_cast(buffer); + }; + for (int i = 0; i < kNumFacesPerCube; ++i) { + std::memcpy(copy + (i * face_size), data.bytes, face_size); + } + filament::Texture::PixelBufferDescriptor desc(copy, num_bytes, format, + type, release_callback); + texture_->setImage(*engine_, 0, std::move(desc), offsets); + ReleaseData(); + } else { + if (num_bytes != data.nbytes) { + mju_error("Texture size does not match data size."); + } + auto callback = +[](void* buffer, size_t size, void* user) { + reinterpret_cast(user)->ReleaseData(); + }; + filament::Texture::PixelBufferDescriptor desc( + data.bytes, data.nbytes, format, type, callback, this); + texture_->setImage(*engine_, 0, std::move(desc), offsets); + } + } + + if (config_.color_space != mjCOLORSPACE_SRGB) { + texture_->generateMipmaps(*engine_); + } +} + +void Texture::ReleaseData() { + if (release_callback_) { + release_callback_(user_data_); + release_callback_ = nullptr; + user_data_ = nullptr; + } +} + +} // namespace mujoco diff --git a/src/experimental/filament/filament/texture.h b/src/experimental/filament/filament/texture.h new file mode 100644 index 0000000000..44b0ea3160 --- /dev/null +++ b/src/experimental/filament/filament/texture.h @@ -0,0 +1,145 @@ +// Copyright 2025 DeepMind Technologies Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_TEXTURE_H_ +#define MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_TEXTURE_H_ + +#include + +#include +#include +#include +#include + +// Functions for creating filament textures. +namespace mujoco { + +// The types of textures we can create. For internal use only. +enum class TextureTarget { + // A standard 2D image with a width and a height. + kNormal2d, + // A 2D texture split up into the 6 faces of a cube. + kCube, +}; + +// The different types of textures we can create for a render target. +// For internal use only. +enum class RenderTargetTextureType { + kColor, + kDepth, + kDepthColor, + kReflectionColor, +}; + +// Pixel formats for textures. +typedef enum mjtPixelFormat_ { + mjPIXEL_FORMAT_UNKNOWN = 0, + mjPIXEL_FORMAT_R8, + mjPIXEL_FORMAT_RGB8, + mjPIXEL_FORMAT_RGBA8, + mjPIXEL_FORMAT_DEPTH32F, + mjPIXEL_FORMAT_KTX, +} mjtPixelFormat; + +// The binary contents of a texture. +struct TextureData { + // Pointer to the image data. If null, an empty texture will be created. + void* bytes; + + // The number of bytes in the image data. + size_t nbytes; + + // Because rendering may be multithreaded, we cannot make assumptions about + // when the image data will finish uploading to the GPU. As such, we will use + // this callback to notify callers when it is safe to free the image data. + void (*release_callback)(void* user_data); + + // User data to pass to the release callback. + void* user_data; +}; + +// Initializes the TextureData to default values. +void DefaultTextureData(TextureData* data); + +// Defines the basic properties of a texture. +struct TextureConfig { + // The width of the texture. For compressed textures (e.g. KTX), this is the + // number of bytes in the compressed data. + int width; + + // The height of the texture. For compressed textures (e.g. KTX), this should + // be 0. + int height; + + // The target of the texture (e.g. 2D, cube, etc.) + mjtTexture target; + + // The format of the pixels in the texture (e.g. RGB8, RGBA8, KTX, etc.) + mjtPixelFormat format; + + // The color space of the texture (e.g. LINEAR, sRGB, etc.) + mjtColorSpace color_space; +}; + +// Initializes the TextureConfig to default values. +void DefaultTextureConfig(TextureConfig* config); + +// Wrapper around a filament::Texture. +class Texture { + public: + // Creates a texture with the given data. + Texture(filament::Engine* engine, const TextureConfig& config); + + // Creates a texture for use with a render target, for internal use. + Texture(filament::Engine* engine, RenderTargetTextureType type, int width, + int height); + + ~Texture(); + + // Uploads the given data to the texture. + void Upload(const TextureData& data); + + // Returns the width of the texture. + int GetWidth() const { return config_.width; } + + // Returns the height of the texture. + int GetHeight() const { return config_.height; } + + // Returns the underlying filament texture. + filament::Texture* GetFilamentTexture() const { return texture_; } + + // Returns any spherical harmonics data associated with the texture. + using SphericalHarmonics = filament::math::float3[9]; + const SphericalHarmonics* GetSphericalHarmonics() const { + return has_spherical_harmonics_ ? &spherical_harmonics_ : nullptr; + } + + Texture(const Texture&) = delete; + Texture& operator=(const Texture&) = delete; + + private: + void ReleaseData(); + + filament::Engine* engine_ = nullptr; + filament::Texture* texture_ = nullptr; + TextureConfig config_; + SphericalHarmonics spherical_harmonics_; + bool has_spherical_harmonics_ = false; + + void* user_data_ = nullptr; + void (*release_callback_)(void* user_data) = nullptr; +}; +} // namespace mujoco + +#endif // MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_TEXTURE_H_ diff --git a/src/experimental/filament/filament/texture_util.cc b/src/experimental/filament/filament/texture_util.cc deleted file mode 100644 index 1cf5567c1c..0000000000 --- a/src/experimental/filament/filament/texture_util.cc +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2025 DeepMind Technologies Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "experimental/filament/filament/texture_util.h" - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - - -namespace mujoco { - -static filament::Texture::Format GetTextureFormat(int num_channels) { - switch (num_channels) { - case 1: - return filament::Texture::Format::R; - case 3: - return filament::Texture::Format::RGB; - case 4: - return filament::Texture::Format::RGBA; - default: - mju_error("Unsupported number of channels: %d", num_channels); - return filament::Texture::Format::UNUSED; - } -} - -static filament::Texture::InternalFormat GetTextureInternalFormat( - int num_channels, bool is_srgb) { - if (is_srgb) { - switch (num_channels) { - case 3: - return filament::Texture::InternalFormat::SRGB8; - case 4: - return filament::Texture::InternalFormat::SRGB8_A8; - default: - mju_error("Unsupported number of channels: %d", num_channels); - return filament::Texture::InternalFormat::UNUSED; - } - } else { - switch (num_channels) { - case 1: - return filament::Texture::InternalFormat::R8; - case 3: - return filament::Texture::InternalFormat::RGB8; - case 4: - return filament::Texture::InternalFormat::RGBA8; - default: - mju_error("Unsupported number of channels: %d", num_channels); - return filament::Texture::InternalFormat::UNUSED; - } - } -} - -filament::Texture* Create2dTexture(filament::Engine* engine, int width, - int height, int num_channels, - const uint8_t* data, bool is_srgb) { - if (num_channels != 1 && num_channels != 3 && num_channels != 4) { - mju_error("Unsupported number of channels: %d", num_channels); - return nullptr; - } - - filament::Texture::Builder builder; - builder.width(width); - builder.height(height); - builder.format(GetTextureInternalFormat(num_channels, is_srgb)); - builder.sampler(filament::Texture::Sampler::SAMPLER_2D); - if (!is_srgb) { - builder.usage(filament::Texture::Usage::GEN_MIPMAPPABLE | - filament::Texture::Usage::SAMPLEABLE | - filament::Texture::Usage::UPLOADABLE); - } - filament::Texture* texture = builder.build(*engine); - - if (data) { - const size_t num_bytes = width * height * sizeof(uint8_t) * num_channels; - const filament::Texture::Format format = GetTextureFormat(num_channels); - texture->setImage( - *engine, 0, - filament::Texture::PixelBufferDescriptor( - data, num_bytes, format, filament::Texture::Type::UBYTE)); - if (!is_srgb) { - texture->generateMipmaps(*engine); - } - } - return texture; -} - -filament::Texture* CreateCubeTexture(filament::Engine* engine, int width, - int height, int num_channels, - const uint8_t* data, bool is_srgb) { - if (num_channels != 3) { - mju_error("Only support RGB cubemaps."); - return nullptr; - } - - const int kNumFacesPerCube = 6; - - int face_height = height; - if (width != height) { - if (width * kNumFacesPerCube != height) { - mju_error("Cube maps must contain 6 square images."); - } - face_height = height / kNumFacesPerCube; - } - if (width != face_height) { - mju_error("Cube map faces must be square."); - } - - filament::Texture::Builder builder; - builder.width(width); - builder.height(face_height); - builder.format(GetTextureInternalFormat(num_channels, is_srgb)); - builder.sampler(filament::Texture::Sampler::SAMPLER_CUBEMAP); - if (!is_srgb) { - builder.usage(filament::Texture::Usage::GEN_MIPMAPPABLE | - filament::Texture::Usage::SAMPLEABLE | - filament::Texture::Usage::UPLOADABLE); - } - filament::Texture* texture = builder.build(*engine); - - const int face_size = width * face_height * num_channels; - const int num_bytes = face_size * kNumFacesPerCube; - - uint8_t* buffer = new uint8_t[num_bytes]; - auto callback = +[](void* buffer, size_t size, void* user) { - delete [] reinterpret_cast(buffer); - }; - - filament::Texture::FaceOffsets offsets(face_size); - if (width == height) { - // Copy the image to all the faces. - for (int i = 0; i < kNumFacesPerCube; ++i) { - std::memcpy(buffer + (i * face_size), data, face_size); - } - } else { - // Use the cubemap as is. - std::memcpy(buffer, data, num_bytes); - } - - if (data) { - filament::Texture::PixelBufferDescriptor desc( - buffer, num_bytes, filament::Texture::Format::RGB, - filament::Texture::Type::UBYTE, callback); - texture->setImage(*engine, 0, std::move(desc), offsets); - if (!is_srgb) { - texture->generateMipmaps(*engine); - } - } - return texture; -} - -filament::Texture* CreateKtxTexture( - filament::Engine* engine, const uint8_t* data, int size, - filament::math::float3* spherical_harmonics_out) { - image::Ktx1Bundle* bundle = new image::Ktx1Bundle(data, size); - if (spherical_harmonics_out) { - bundle->getSphericalHarmonics(spherical_harmonics_out); - } - const bool is_srgb = false; - return ktxreader::Ktx1Reader::createTexture(engine, bundle, is_srgb); -} -} // namespace mujoco diff --git a/src/experimental/filament/filament/texture_util.h b/src/experimental/filament/filament/texture_util.h deleted file mode 100644 index 3136b81ce3..0000000000 --- a/src/experimental/filament/filament/texture_util.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2025 DeepMind Technologies Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_TEXTURE_UTIL_H_ -#define MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_TEXTURE_UTIL_H_ - -#include - -#include -#include -#include - -// Functions for creating filament textures. -namespace mujoco { - -// Creates a filament Texture for the given 2D texture. -filament::Texture* Create2dTexture(filament::Engine* engine, int width, - int height, int num_channels, - const uint8_t* data, bool is_srgb); - -// Creates a filament Texture for the given cube texture. -filament::Texture* CreateCubeTexture(filament::Engine* engine, int width, - int height, int num_channels, - const uint8_t* data, bool is_srgb); - -// Creates a filament Texture for the given KTX payload. -filament::Texture* CreateKtxTexture( - filament::Engine* engine, const uint8_t* data, int size, - filament::math::float3* spherical_harmonics_out); - -} // namespace mujoco - -#endif // MUJOCO_SRC_EXPERIMENTAL_FILAMENT_FILAMENT_TEXTURE_UTIL_H_ diff --git a/src/experimental/filament/render_context_filament.cc b/src/experimental/filament/render_context_filament.cc index 2e00bc7966..d49069f534 100644 --- a/src/experimental/filament/render_context_filament.cc +++ b/src/experimental/filament/render_context_filament.cc @@ -58,7 +58,7 @@ void mjrf_defaultContext(mjrContext* con) { } void mjrf_makeContext(const mjModel* m, mjrContext* con, int fontscale) { - mjr_freeContext(con); + mjrf_freeContext(con); mjrFilamentConfig cfg; mjrf_defaultFilamentConfig(&cfg); cfg.width = m->vis.global.offwidth; @@ -72,7 +72,7 @@ void mjrf_freeContext(mjrContext* con) { delete g_filament_context; g_filament_context = nullptr; } - mjr_defaultContext(con); + mjrf_defaultContext(con); } void mjrf_render(mjrRect viewport, mjvScene* scn, const mjrContext* con) { diff --git a/src/experimental/mjz/mjz_decoder.cc b/src/experimental/mjz/mjz_decoder.cc index 8defa31d96..ee29a80f6c 100644 --- a/src/experimental/mjz/mjz_decoder.cc +++ b/src/experimental/mjz/mjz_decoder.cc @@ -192,7 +192,7 @@ static mjSpec* ParseZipBuffer(const void* buffer, int nbuffer, const char* name, return mj_parseXML(root.c_str(), vfs, error, error_sz); } -mjPLUGIN_LIB_INIT { +mjPLUGIN_LIB_INIT(mjz_decoder) { mjpDecoder decoder; decoder.content_type = "application/zip"; decoder.extension = ".mjz|.zip"; diff --git a/src/experimental/platform/object_launcher_plugin.cc b/src/experimental/platform/object_launcher_plugin.cc index ad065bae58..4727aa7eba 100644 --- a/src/experimental/platform/object_launcher_plugin.cc +++ b/src/experimental/platform/object_launcher_plugin.cc @@ -172,7 +172,7 @@ class ObjectLauncher { } // namespace mujoco::studio -mjPLUGIN_LIB_INIT { +mjPLUGIN_LIB_INIT(object_launcher) { using mujoco::studio::ObjectLauncher; static ObjectLauncher plugin; diff --git a/src/experimental/platform/renderer.cc b/src/experimental/platform/renderer.cc index 42c77d5fa0..c34d0c0251 100644 --- a/src/experimental/platform/renderer.cc +++ b/src/experimental/platform/renderer.cc @@ -240,7 +240,7 @@ void Renderer::UpdateFps() { } // namespace mujoco::platform -mjPLUGIN_LIB_INIT { +mjPLUGIN_LIB_INIT(renderer) { mujoco::platform::GuiPlugin plugin; plugin.name = "Filament"; plugin.update = [](mujoco::platform::GuiPlugin* self) { diff --git a/src/user/user_api.cc b/src/user/user_api.cc index f5b569d9e1..2317cc3a0d 100644 --- a/src/user/user_api.cc +++ b/src/user/user_api.cc @@ -15,6 +15,7 @@ #include "user/user_api.h" #include +#include #include #include #include @@ -1120,6 +1121,166 @@ const char* mjs_setToAdhesion(mjsActuator* actuator, double gain) { +const char* mjs_setToDCMotor(mjsActuator* actuator, double motorconst[2], double resistance, + double nominal[3], double saturation[4], double inductance[2], + double cogging[3], double controller[5], double thermal[6], + double lugre[6], int input_mode) { + double Kt = motorconst[0]; // torque constant + double Ke = motorconst[1]; // back-EMF constant + double R = resistance; // electrical resistance + double vn = nominal[0]; // nominal voltage + double tau0 = nominal[1]; // stall torque + double omega0 = nominal[2]; // no-load speed + + // derive Ke from nominal: omega0 = vn*Ke / (Ke^2 + R*B) + if (vn > 0 && Ke <= 0 && omega0 > 0) { + // viscous damping (linear), add lugre sigma2 contribution if any + double B = actuator->damping[0]; + if (lugre[0] > 0) B += lugre[2]; + + if (B > 0 && R > 0) { + // R known: solve quadratic Ke^2*omega0 - Ke*vn + R*B*omega0 = 0 + double disc = vn*vn - 4*R*B*omega0*omega0; + Ke = disc > 0 ? (vn + sqrt(disc)) / (2*omega0) : vn / omega0; + } else if (B > 0 && tau0 > 0) { + // R from nominal (tau0 = Ke*vn/R, so R = Ke*vn/tau0) + // substituting into omega0 = vn*Ke/(Ke^2 + R*B): + // omega0 = vn/(Ke + vn*B/tau0) => Ke = vn/omega0 - vn*B/tau0 + double Ke_exact = vn / omega0 - vn*B / tau0; + Ke = Ke_exact > 0 ? Ke_exact : vn / omega0; + } else { + // B = 0 or insufficient data for B-correction: omega0 = vn*Ke/Ke^2 = vn/Ke + Ke = vn / omega0; + } + } + + // resolve effective motor constant K from [Kt, Ke] + double K = (Kt > 0 && Ke > 0) ? sqrt(Kt * Ke) : + (Kt > 0) ? Kt : Ke; + + // derive R from nominal: tau0 = K*vn/R + if (R == 0 && vn > 0 && tau0 > 0 && K > 0) { + R = K * vn / tau0; + } + + if (K <= 0) return "DC motor: motor constant K must be positive"; + if (R <= 0) return "DC motor: resistance R must be positive"; + + // set types + actuator->dyntype = mjDYN_DCMOTOR; + actuator->gaintype = mjGAIN_DCMOTOR; + actuator->biastype = mjBIAS_DCMOTOR; + + // gainprm: [R, K, alpha, T0] + actuator->gainprm[0] = R; + actuator->gainprm[1] = K; + + // controller parameters: gainprm[4:6] for kp, ki, kd + actuator->gainprm[4] = controller[0]; // kp + actuator->gainprm[5] = controller[1]; // ki + actuator->gainprm[6] = controller[2]; // kd + + // controller parameters: dynprm[7,8] for slewmax, Imax + actuator->dynprm[7] = controller[3]; // slewmax + actuator->dynprm[8] = controller[4]; // Imax + + // saturation: [tau_max, i_max, (di/dt)_max, v_max] + if (saturation[2] > 0) { + actuator->dynprm[1] = saturation[2]; // (di/dt)_max + } + if (saturation[3] > 0) { + actuator->gainprm[7] = saturation[3]; // v_max + } + + // saturation -> forcerange + if (saturation[0] > 0 || saturation[1] > 0) { + double tau_max = saturation[0]; + if (tau_max == 0 && saturation[1] > 0) { + tau_max = K * saturation[1]; // tau_max = K * i_max + } + actuator->forcerange[0] = -tau_max; + actuator->forcerange[1] = tau_max; + actuator->forcelimited = 1; + } + + // cogging: [amplitude, periodicity, phase] -> biasprm[0:3] + actuator->biasprm[0] = cogging[0]; // amplitude + actuator->biasprm[1] = cogging[1]; // periodicity + actuator->biasprm[2] = cogging[2]; // phase + + // count activation variables: slot order is slew, integral, temperature, bristle, current + int actdim = 0; + + // inductance: [L, te] + if (inductance[0] < 0) return "DC motor: inductance must be non-negative"; + if (inductance[1] < 0) return "DC motor: electrical time constant must be non-negative"; + double te = inductance[0] > 0 ? inductance[0] / R : inductance[1]; + actuator->dynprm[0] = te; + if (te > 0) { + actdim++; + } + + // controller states: slew rate limiting + if (controller[3] > 0) { // slewmax + actdim++; + } + + // controller states: integral + if (controller[1] > 0) { // ki + actdim++; + } + + // thermal -> temperature activation + if (thermal[0] > 0 || thermal[1] > 0 || thermal[2] > 0) { + double RT = thermal[0]; // thermal resistance + double C = thermal[1]; // thermal capacitance + double tth = thermal[2]; // thermal time constant + double alpha = thermal[3]; // temperature coefficient + double T0 = thermal[4]; // reference temperature + double Ta = thermal[5]; // ambient temperature + + if (tth > 0 && RT > 0 && C == 0) { + C = tth / RT; + } else if (tth > 0 && C > 0 && RT == 0) { + RT = tth / C; + } else if (tth == 0 && RT > 0 && C > 0) { + tth = RT * C; + } + + if (RT <= 0) return "DC motor: thermal resistance must be positive"; + if (C <= 0) return "DC motor: thermal capacitance must be positive"; + + actuator->dynprm[2] = RT; + actuator->dynprm[3] = C; + actuator->dynprm[4] = Ta; + actuator->gainprm[2] = alpha; + actuator->gainprm[3] = T0; + actdim++; + } + + // lugre: {stiffness, damping, viscous, coulomb, static, stribeck} + if (lugre[0] > 0) { + actuator->dynprm[5] = lugre[0]; // stiffness -> sigma0 + actuator->dynprm[6] = lugre[1]; // damping -> sigma1 + actuator->damping[0] += lugre[2]; // viscous -> sigma2 + actuator->biasprm[3] = lugre[3]; // coulomb -> tau_c + actuator->biasprm[4] = lugre[4]; // static -> tau_s + actuator->biasprm[5] = lugre[5]; // stribeck -> omega_s + actdim++; + } + + // set input mode and activation dimension + actuator->gainprm[8] = input_mode; + actuator->actdim = actdim; + + // enforce actlimited = 0; homogeneous bounds are invalid across DC motor states + actuator->actlimited = 0; + + return ""; +} + + + // get spec from body mjSpec* mjs_getSpec(mjsElement* element) { return &(static_cast(element)->model->spec); diff --git a/src/user/user_objects.cc b/src/user/user_objects.cc index c83951571b..5198cc20d0 100644 --- a/src/user/user_objects.cc +++ b/src/user/user_objects.cc @@ -7222,20 +7222,20 @@ void mjCActuator::Compile(void) { // check and set actdim if (!plugin.active) { - if (actdim > 1 && dyntype != mjDYN_USER) { - throw mjCError(this, "actdim > 1 is only allowed for dyntype 'user' in actuator"); + if (actdim > 1 && dyntype != mjDYN_USER && dyntype != mjDYN_DCMOTOR) { + throw mjCError(this, "actdim > 1 is only allowed for dyntype 'user' and 'dcmotor'"); } if (actdim == 1 && dyntype == mjDYN_NONE) { throw mjCError(this, "invalid actdim 1 in stateless actuator"); } - if (actdim == 0 && dyntype != mjDYN_NONE) { + if (actdim == 0 && dyntype != mjDYN_NONE && dyntype != mjDYN_DCMOTOR) { throw mjCError(this, "invalid actdim 0 in stateful actuator"); } } - // set actdim + // set actdim to 1 if it is unset and type is standard one-activation dyntype if (actdim < 0) { - actdim = (dyntype != mjDYN_NONE); + actdim = (dyntype != mjDYN_NONE && dyntype != mjDYN_DCMOTOR); } // check muscle parameters diff --git a/src/xml/xml_native_reader.cc b/src/xml/xml_native_reader.cc index d40e6a9adb..352790b5dc 100644 --- a/src/xml/xml_native_reader.cc +++ b/src/xml/xml_native_reader.cc @@ -206,6 +206,10 @@ std::vector MJCF[nMJCF] = { "lmin", "lmax", "vmax", "fpmax", "fvmax"}, {"adhesion", "?", "forcelimited", "ctrlrange", "forcerange", "gain", "user", "group", "nsample", "interp", "delay"}, + {"dcmotor", "?", "ctrllimited", "ctrlrange", + "gear", "damping", "armature", "cranklength", "user", "group", "nsample", "interp", "delay", + "motorconst", "resistance", "nominal", "saturation", + "inductance", "cogging", "controller", "input", "thermal", "lugre"}, {">"}, {"extension", "*"}, @@ -436,6 +440,12 @@ std::vector MJCF[nMJCF] = { "lmin", "lmax", "vmax", "fpmax", "fvmax"}, {"adhesion", "*", "name", "class", "group", "nsample", "interp", "delay", "forcelimited", "ctrlrange", "forcerange", "user", "body", "gain"}, + {"dcmotor", "*", "name", "class", "group", "nsample", "interp", "delay", + "ctrllimited", "ctrlrange", + "lengthrange", "gear", "damping", "armature", "cranklength", "user", + "joint", "jointinparent", "tendon", "slidersite", "cranksite", "site", "refsite", + "motorconst", "resistance", "nominal", "saturation", + "inductance", "cogging", "controller", "thermal", "lugre", "input"}, {"plugin", "*", "name", "class", "plugin", "instance", "group", "nsample", "interp", "delay", "ctrllimited", "forcelimited", "actlimited", "ctrlrange", "forcerange", "actrange", "lengthrange", "gear", "damping", "armature", "cranklength", "joint", "jointinparent", @@ -724,33 +734,45 @@ const mjMap mark_map[mark_sz] = { // dyn type -const int dyn_sz = 6; +const int dyn_sz = 7; const mjMap dyn_map[dyn_sz] = { {"none", mjDYN_NONE}, {"integrator", mjDYN_INTEGRATOR}, {"filter", mjDYN_FILTER}, {"filterexact", mjDYN_FILTEREXACT}, {"muscle", mjDYN_MUSCLE}, + {"dcmotor", mjDYN_DCMOTOR}, {"user", mjDYN_USER} }; +// dcmotor controller input mode +const int dcmotorinput_sz = 3; +const mjMap dcmotorinput_map[dcmotorinput_sz] = { + {"voltage", 0}, + {"position", 1}, + {"velocity", 2} +}; + + // gain type -const int gain_sz = 4; +const int gain_sz = 5; const mjMap gain_map[gain_sz] = { {"fixed", mjGAIN_FIXED}, {"affine", mjGAIN_AFFINE}, {"muscle", mjGAIN_MUSCLE}, + {"dcmotor", mjGAIN_DCMOTOR}, {"user", mjGAIN_USER} }; // bias type -const int bias_sz = 4; +const int bias_sz = 5; const mjMap bias_map[bias_sz] = { {"none", mjBIAS_NONE}, {"affine", mjBIAS_AFFINE}, {"muscle", mjBIAS_MUSCLE}, + {"dcmotor", mjBIAS_DCMOTOR}, {"user", mjBIAS_USER} }; @@ -2498,6 +2520,54 @@ void mjXReader::OneActuator(XMLElement* elem, mjsActuator* actuator) { err = mjs_setToAdhesion(actuator, gain); } + // DC motor + else if (type == "dcmotor") { + bool inherited = (actuator->gaintype == mjGAIN_DCMOTOR); + double motorconst[2] = {inherited ? actuator->gainprm[1] : 0, 0}; + double resistance = inherited ? actuator->gainprm[0] : 0; + double nominal[3] = {0, 0, 0}; + double saturation[4] = {0, 0, + inherited ? actuator->dynprm[1] : 0, + inherited ? actuator->gainprm[8] : 0}; + double controller[5] = {inherited ? actuator->gainprm[5] : 0, + inherited ? actuator->gainprm[6] : 0, + inherited ? actuator->gainprm[7] : 0, + inherited ? actuator->dynprm[7] : 0, + inherited ? actuator->dynprm[8] : 0}; + double inductance[2] = {0, inherited ? actuator->dynprm[0] : 0}; + double cogging[3] = {inherited ? actuator->biasprm[0] : 0, + inherited ? actuator->biasprm[1] : 0, + inherited ? actuator->biasprm[2] : 0}; + double thermal[6] = {inherited ? actuator->dynprm[2] : 0, + inherited ? actuator->dynprm[3] : 0, + 0, + inherited ? actuator->gainprm[2] : 0, + inherited ? actuator->gainprm[3] : 0, + inherited ? actuator->dynprm[4] : 0}; + double lugre[6] = {inherited ? actuator->dynprm[5] : 0, + inherited ? actuator->dynprm[6] : 0, + inherited ? actuator->damping[0] : 0, + inherited ? actuator->biasprm[3] : 0, + inherited ? actuator->biasprm[4] : 0, + inherited ? actuator->biasprm[5] : 0}; + int input_mode = inherited ? (int)actuator->gainprm[9] : 0; + ReadAttr(elem, "motorconst", 2, motorconst, text, false, false); + ReadAttr(elem, "resistance", 1, &resistance, text); + ReadAttr(elem, "nominal", 3, nominal, text, false, false); + ReadAttr(elem, "saturation", 4, saturation, text, false, false); + ReadAttr(elem, "inductance", 2, inductance, text, false, false); + ReadAttr(elem, "cogging", 3, cogging, text, false, false); + ReadAttr(elem, "controller", 5, controller, text, false, false); + ReadAttr(elem, "thermal", 6, thermal, text, false, false); + ReadAttr(elem, "lugre", 6, lugre, text, false, false); + if (MapValue(elem, "input", &input_mode, dcmotorinput_map, dcmotorinput_sz)) { + // successfully parsed + } + err = mjs_setToDCMotor(actuator, motorconst, resistance, + nominal, saturation, inductance, + cogging, controller, thermal, lugre, input_mode); + } + else if (type == "plugin") { OnePlugin(elem, &actuator->plugin); int n; @@ -2962,7 +3032,8 @@ void mjXReader::Default(XMLElement* section, const mjsDefault* def, const mjVFS* name == "intvelocity" || name == "cylinder" || name == "muscle" || - name == "adhesion") { + name == "adhesion" || + name == "dcmotor") { OneActuator(elem, def->actuator); } diff --git a/src/xml/xml_native_reader.h b/src/xml/xml_native_reader.h index 4e0369ce6a..8c568b22dc 100644 --- a/src/xml/xml_native_reader.h +++ b/src/xml/xml_native_reader.h @@ -102,7 +102,7 @@ class mjXReader : public mjXBase { }; // MJCF schema -#define nMJCF 246 +#define nMJCF 248 extern std::vector MJCF[nMJCF]; #endif // MUJOCO_SRC_XML_XML_NATIVE_READER_H_ diff --git a/src/xml/xml_native_writer.cc b/src/xml/xml_native_writer.cc index 144ffe6521..e9aa9ea7b3 100644 --- a/src/xml/xml_native_writer.cc +++ b/src/xml/xml_native_writer.cc @@ -871,7 +871,7 @@ void mjXWriter::OneActuator(XMLElement* elem, const mjCActuator* actuator, mjCDe if (writingdefaults) { WriteAttrInt(elem, "actdim", actuator->actdim, def->Actuator().actdim); } else { - int default_actdim = actuator->dyntype == mjDYN_NONE ? 0 : 1; + int default_actdim = (actuator->dyntype != mjDYN_NONE && actuator->dyntype != mjDYN_DCMOTOR); WriteAttrInt(elem, "actdim", actuator->actdim, default_actdim); } WriteAttrKey(elem, "dyntype", dyn_map, dyn_sz, actuator->dyntype, def->Actuator().dyntype); diff --git a/test/engine/engine_derivative_test.cc b/test/engine/engine_derivative_test.cc index df8f863bbd..c32e6b1e34 100644 --- a/test/engine/engine_derivative_test.cc +++ b/test/engine/engine_derivative_test.cc @@ -91,6 +91,8 @@ static const char* const kDampedPendulumPath = "engine/testdata/derivative/damped_pendulum.xml"; static const char* const kLinearPath = "engine/testdata/derivative/linear.xml"; +static const char* const kDCMotorPath = + "engine/testdata/derivative/dcmotor.xml"; static const char* const kModelPath = "testdata/model.xml"; // compare analytic and finite-difference d_smooth/d_qvel @@ -99,9 +101,12 @@ TEST_F(DerivativeTest, SmoothDvel) { for (const char* local_path : {kEnergyConservingPendulumPath, kTumblingThinObjectPath, kDampedActuatorsPath, - kDamperActuatorsPath}) { + kDamperActuatorsPath, + kDCMotorPath}) { const std::string xml_path = GetTestDataFilePath(local_path); - mjModel* model = mj_loadXML(xml_path.c_str(), nullptr, nullptr, 0); + char error[1024] = ""; + mjModel* model = mj_loadXML(xml_path.c_str(), nullptr, error, sizeof(error)); + ASSERT_THAT(model, testing::NotNull()) << "Failed to load model: " << error; int nD = model->nD; mjData* data = mj_makeData(model); @@ -758,9 +763,12 @@ TEST_F(DerivativeTest, DenseSparseRneEquivalent) { for (const char* local_path : {kEnergyConservingPendulumPath, kTumblingThinObjectPath, kDampedActuatorsPath, - kDamperActuatorsPath}) { + kDamperActuatorsPath, + kDCMotorPath}) { const std::string xml_path = GetTestDataFilePath(local_path); - mjModel* model = mj_loadXML(xml_path.c_str(), nullptr, nullptr, 0); + char error[1024] = ""; + mjModel* model = mj_loadXML(xml_path.c_str(), nullptr, error, sizeof(error)); + ASSERT_THAT(model, testing::NotNull()) << "Failed to load model: " << error; int nD = model->nD; mjtNum* qDeriv = (mjtNum*) mju_malloc(sizeof(mjtNum)*nD); mjData* data = mj_makeData(model); diff --git a/test/engine/engine_forward_test.cc b/test/engine/engine_forward_test.cc index cee67e297f..f3acf0aa59 100644 --- a/test/engine/engine_forward_test.cc +++ b/test/engine/engine_forward_test.cc @@ -1148,6 +1148,949 @@ TEST_F(ActuatorTest, DampRatioTendon) { mj_deleteModel(model); } +// ----------------------- DC motor actuators ---------------------------------- + +using DCMotorTest = MujocoTest; + +TEST_F(DCMotorTest, IntVelocityEquivalence) { + static constexpr char xml[] = R"( + + + + + + + + + + + + + + + + + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + // Apply a time-varying velocity command + while (data->time < 1.0) { + data->ctrl[0] = mju_sin(20 * data->time); + data->ctrl[1] = mju_sin(20 * data->time); + mj_step(model, data); + + // Both actuators should integrate identical states + EXPECT_MJTNUM_EQ(data->act[0], data->act[1]); + + // Both bodies should move identically + EXPECT_NEAR(data->qpos[0], data->qpos[1], MjTol(1e-14, 1e-7)); + EXPECT_NEAR(data->qvel[0], data->qvel[1], MjTol(1e-14, 1e-7)); + EXPECT_NEAR(data->qacc[0], data->qacc[1], MjTol(1e-14, 1e-6)); + + // Both actuators should produce identical force + EXPECT_NEAR(data->actuator_force[0], data->actuator_force[1], + MjTol(1e-14, 1e-6)); + } + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, StatelessSteadyState) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + double K = 0.05; + double R = 2.0; + double V = 12.0; + double omega = 3.0; + + data->ctrl[0] = V; + data->qvel[0] = omega; + mj_forward(model, data); + + double expected_force = K / R * (V - K * omega); + EXPECT_NEAR(data->actuator_force[0], expected_force, MjTol(1e-12, 1e-5)); + EXPECT_EQ(model->actuator_actnum[0], 0); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, CurrentFilterConverges) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + ASSERT_EQ(model->actuator_actnum[0], 1); + + double K = 0.05; + double R = 2.0; + double V = 12.0; + + data->ctrl[0] = V; + for (int i = 0; i < 10000; i++) { + mj_step(model, data); + } + + double omega = data->qvel[0]; + double i_ss = V / R - K / R * omega; + double expected_force = K * i_ss; + + EXPECT_NEAR(data->act[0], i_ss, MjTol(1e-6, 1e-4)); + EXPECT_NEAR(data->actuator_force[0], expected_force, MjTol(1e-6, 1e-4)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, CurrentFilterExactIntegration) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + double R = 2.0; + double te = 0.01 / R; + double V = 12.0; + + data->ctrl[0] = V; + mj_step(model, data); + + double h = model->opt.timestep; + double exact_current = V / R * (1 - mju_exp(-h / te)); + EXPECT_NEAR(data->act[0], exact_current, MjTol(1e-10, 1e-4)); + + double euler_current = V / R * h / te; + EXPECT_GT(std::abs(data->act[0] - euler_current), + std::abs(data->act[0] - exact_current)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, CoggingTorque) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + double A = 0.1, Np = 6, phi = 0; + double K = 0.05, R = 2.0; + double V = 5.0; + double pos = 1.0; + + data->ctrl[0] = V; + data->qpos[0] = pos; + mj_forward(model, data); + + double electrical_force = K / R * V; + double cogging = A * mju_sin(Np * pos + phi); + EXPECT_NEAR(data->actuator_force[0], electrical_force + cogging, + MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, CoggingBypassesSaturation) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + double A = 0.1, Np = 6, phi = 0; + double pos = 1.0; + + data->ctrl[0] = 100.0; + data->qpos[0] = pos; + mj_forward(model, data); + + double cogging = A * mju_sin(Np * pos + phi); + EXPECT_NEAR(model->actuator_forcerange[1], 0.001, MjTol(1e-12, 1e-5)); + EXPECT_GT(mju_abs(data->actuator_force[0]), 0.001); + EXPECT_NEAR(data->actuator_force[0], 0.001 + cogging, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, LuGreViscousFriction) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + ASSERT_EQ(model->actuator_actnum[0], 1); + + double sigma1 = 1, sigma2 = 0.01; + double K = 0.05, R = 2.0; + double omega = 2.0; + + data->ctrl[0] = 0; + data->qvel[0] = omega; + mj_forward(model, data); + + EXPECT_MJTNUM_EQ(model->actuator_damping[0], sigma2); + double electrical_force = K / R * (0 - K * omega); + double z = data->act[model->actuator_actadr[0]]; + double z_dot = data->act_dot[model->actuator_actadr[0]]; + double lugre_force = 100 * z + sigma1 * z_dot; + EXPECT_NEAR(data->actuator_force[0], electrical_force - lugre_force, + MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, ThermalRiseAndFall) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + int adr = model->actuator_actadr[0]; + ASSERT_EQ(model->actuator_actnum[0], 1); + EXPECT_EQ(data->act[adr], 0); + + double R = 2.0, V = 10.0; + double RT = 10.0, C = 5.0; + double h = model->opt.timestep; + double P = V * V / R; + + data->ctrl[0] = V; + + mj_step(model, data); + double dT1 = h * P / C; + EXPECT_NEAR(data->act[adr], dT1, MjTol(1e-11, 1e-4)); + + mj_step(model, data); + double dT2 = dT1 + h * (P - dT1 / RT) / C; + EXPECT_NEAR(data->act[adr], dT2, MjTol(1e-11, 1e-4)); + + data->ctrl[0] = 0; + mj_step(model, data); + double dT3 = dT2 + h * (0 - dT2 / RT) / C; + EXPECT_NEAR(data->act[adr], dT3, MjTol(1e-11, 1e-4)); + EXPECT_LT(data->act[adr], dT2); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, ThermalSteadyState) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + double R = 2.0, V = 10.0; + double RT = 0.1; + double dT_ss = RT * V * V / R; + + data->ctrl[0] = V; + for (int i = 0; i < 10000; i++) { + mj_step(model, data); + } + + int adr = model->actuator_actadr[0]; + EXPECT_NEAR(data->act[adr], dT_ss, 1e-4); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, ThermalAffectsForce) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + double K = 0.05, R = 2.0, V = 10.0; + double alpha = 0.004; + int adr = model->actuator_actadr[0]; + + data->ctrl[0] = V; + data->act[adr] = 0; + mj_forward(model, data); + double force_cold = data->actuator_force[0]; + EXPECT_NEAR(force_cold, K / R * V, MjTol(1e-12, 1e-5)); + + double dT = 50; + data->act[adr] = dT; + mj_forward(model, data); + double R_hot = R * (1 + alpha * dT); + double force_hot = data->actuator_force[0]; + EXPECT_NEAR(force_hot, K / R_hot * V, MjTol(1e-12, 1e-5)); + EXPECT_LT(force_hot, force_cold); + + mj_deleteData(data); + mj_deleteModel(model); +} + +// Temperature slot must be correctly offset past slew and integral states. +TEST_F(DCMotorTest, ThermalAffectsForceWithController) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + // slot order: slew(0), integral(1), temperature(2) + ASSERT_EQ(model->actuator_actnum[0], 3); + int adr = model->actuator_actadr[0]; + int temp_adr = adr + 2; // temperature is slot 2 + + double K = 0.05, R = 2.0, alpha = 0.004; + double dT = 50; + data->act[adr] = 1.0; // slew state = ctrl: no rate-limiting applied + data->act[adr + 1] = 0.0; // integral state x_I = 0 + data->act[temp_adr] = dT; // temperature rise above ambient + data->ctrl[0] = 1.0; // position setpoint = 1.0, qpos = 0, error = 1.0 + mj_forward(model, data); + + // u_eff = ctrl = 1.0 (no slew applied since act[slew] == ctrl) + // V = kp*(u_eff - length) + ki*x_I - kd*omega = 1.0*1.0 + 1.0*0.0 - 0*0 = 1.0 + // R(T) = 2.0 * (1 + 0.004 * 50) = 2.4 + // stateless (no te): force = K/R(T) * V = 0.05/2.4 * 1.0 + double R_hot = R * (1 + alpha * dT); + EXPECT_NEAR(data->actuator_force[0], K / R_hot * 1.0, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, StatelessPositionMode) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + // Position target 5.0, current pos 0.0, current vel 0.0 + data->ctrl[0] = 5.0; + mj_forward(model, data); + + // V = Kp * (u - theta) = 2.0 * 5.0 = 10.0 + // force = K / R * V + bias = (0.05 / 2.0) * 10.0 + 0 = 0.25 + EXPECT_NEAR(data->actuator_force[0], 0.25, MjTol(1e-12, 1e-5)); + + // Velocity penalty + data->qvel[0] = 2.0; + mj_forward(model, data); + // V = 10.0 - Kd * omega = 10.0 - (0.5 * 2.0) = 9.0 + // bias = - K^2 / R * omega = -0.0025 / 2.0 * 2.0 = -0.0025 + // force = K / R * V + bias = 0.225 - 0.0025 = 0.2225 + EXPECT_NEAR(data->actuator_force[0], 0.2225, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, StatelessVelocityMode) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + // Velocity target 4.0, current vel 1.0 + data->ctrl[0] = 4.0; + data->qvel[0] = 1.0; + mj_forward(model, data); + + // V = Kp * (u - omega) = 3.0 * (4.0 - 1.0) = 9.0 + // bias = - K^2 / R * omega = -0.0025 / 2.0 * 1.0 = -0.00125 + // force = K / R * V + bias = (0.05 / 2.0) * 9.0 - 0.00125 = 0.22375 + EXPECT_NEAR(data->actuator_force[0], 0.22375, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, StatefulPositionMode) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + // Controller states: 1 for slew, 1 for ki -> actnum = 2 + ASSERT_EQ(model->actuator_actnum[0], 2); + int adr = model->actuator_actadr[0]; + + // Current states + double u_prev = 1.0; + double x_I = 2.0; + data->act[adr] = u_prev; + data->act[adr+1] = x_I; + + // target 5.0 position, current 0.0 + data->ctrl[0] = 5.0; + data->qvel[0] = 0.5; + mj_forward(model, data); + + // slew bounding: s = 10.0, dt = 0.001. max_change = 0.01 + // Target = 5.0. It is upper bounded by u_prev + 0.01 = 1.01 + EXPECT_NEAR(data->act_dot[adr], 10.0, MjTol(1e-12, 1e-5)); + + // PI error: error = u_eff - length = 1.01 - 0.0 = 1.01 + EXPECT_NEAR(data->act_dot[adr+1], 1.01, MjTol(1e-12, 1e-5)); + + // V = Kp(u_eff - length) + Ki * x_I - Kd * omega + // V = 2.0 * 1.01 + 0.5 * 2.0 - 0.1 * 0.5 = 2.97 + // bias = - K^2/R * omega = -(0.05)^2 / 2.0 * 0.5 = -0.000625 + // force = K/R * V + bias = 0.025 * 2.97 - 0.000625 = 0.073625 + EXPECT_NEAR(data->actuator_force[0], 0.073625, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, StatefulPositionWithCurrentMode) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + // Controller states: slew (0), ki (1), current (2). actnum = 3 + ASSERT_EQ(model->actuator_actnum[0], 3); + int adr = model->actuator_actadr[0]; + + double u_prev = 1.0; + double x_I = 2.0; + double current = 0.5; + data->act[adr] = u_prev; + data->act[adr+1] = x_I; + data->act[adr+2] = current; + + // Target 5.0 position, velocity 0.5 + data->ctrl[0] = 5.0; + data->qvel[0] = 0.5; + mj_forward(model, data); + + // Slew bounding: max_change = 0.01, u_eff = 1.01 + EXPECT_NEAR(data->act_dot[adr], 10.0, MjTol(1e-12, 1e-5)); + + // PI error: error = u_eff - length = 1.01 + EXPECT_NEAR(data->act_dot[adr+1], 1.01, MjTol(1e-12, 1e-5)); + + // Voltage computation: + // V = Kp(u_eff - length) + Ki * x_I - Kd * omega + // V = 2.0 * 1.01 + 0.5 * 2.0 - 0.1 * 0.5 = 2.97 + + // Current filter: + // t_e = L / R = 1.0 / 2.0 = 0.5 + // di/dt = (V/R - K/R * omega - i) / t_e + // di/dt = (2.97/2.0 - 0.05/2.0 * 0.5 - 0.5) / 0.5 + // di/dt = (1.485 - 0.0125 - 0.5) / 0.5 = 0.9725 / 0.5 = 1.945 + EXPECT_NEAR(data->act_dot[adr+2], 1.945, MjTol(1e-12, 1e-5)); + + // Force is just K * current since current is stateful + EXPECT_NEAR(data->actuator_force[0], 0.05 * 0.5, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, StatefulVelocityMode) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + // Controller states: 1 for ki (no slew) + ASSERT_EQ(model->actuator_actnum[0], 1); + int adr = model->actuator_actadr[0]; + + double x_I = 2.0; // Exactly at Imax limit (Imax = 2.0) + data->act[adr] = x_I; + + // target vel 4.0, current vel 1.0 + data->ctrl[0] = 4.0; + data->qvel[0] = 1.0; + mj_forward(model, data); + + // integrate command directly: error = target = 4.0 + // since x_I == Imax (2.0) and error (4.0) > 0, act_dot should be clamped to 0 + EXPECT_NEAR(data->act_dot[adr], 0.0, MjTol(1e-12, 1e-5)); + + // V = Kp * (u_eff - omega) + Ki * (x_I - length) + // V = 3.0 * (4.0 - 1.0) + 1.0 * (2.0 - 0.0) = 9.0 + 2.0 = 11.0 + // bias = - K^2/R * omega = -(0.05)^2 / 2.0 * 1.0 = -0.00125 + // force = K/R * V + bias = 0.025 * 11.0 - 0.00125 = 0.275 - 0.00125 = 0.27375 + EXPECT_NEAR(data->actuator_force[0], 0.27375, MjTol(1e-12, 1e-5)); + + // repeat with non-zero joint position + data->qpos[0] = 1.5; + mj_forward(model, data); + + // V = 3.0 * (4.0 - 1.0) + 1.0 * (2.0 - 1.5) = 9.0 + 0.5 = 9.5 + // force = K/R * V + bias = 0.025 * 9.5 - 0.00125 = 0.2375 - 0.00125 = 0.23625 + EXPECT_NEAR(data->actuator_force[0], 0.23625, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, CurrentPlusThermal) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + ASSERT_EQ(model->actuator_actnum[0], 2); + int adr = model->actuator_actadr[0]; + + double K = 0.05, R = 2.0, V = 12.0; + double te = 0.01 / R; + double RT = 10.0, C = 5.0; + + double current = 3.0; + double dT = 10.0; + data->act[adr] = dT; + data->act[adr+1] = current; + data->ctrl[0] = V; + mj_forward(model, data); + + EXPECT_NEAR(data->actuator_force[0], K * current, MjTol(1e-12, 1e-5)); + + double R_hot = R * (1 + 0.004 * dT); + double T_dot = (R_hot * current * current - dT / RT) / C; + EXPECT_NEAR(data->act_dot[adr], T_dot, MjTol(1e-10, 1e-4)); + + double omega = data->qvel[0]; + double i_dot = (V/R_hot - K/R_hot*omega - current) / te; + EXPECT_NEAR(data->act_dot[adr+1], i_dot, MjTol(1e-10, 1e-3)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, CurrentRateLimit) { + // Verifies that saturation:current_rate clamps di/dt. + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + ASSERT_EQ(model->actuator_actnum[0], 1); + int adr = model->actuator_actadr[0]; + + double V = 12.0; + double dimax = 100.0; // A/s rate limit + + // unclamped: i_dot = (V/R - 0 - 0) / te = 6 / 0.005 = 1200 A/s >> dimax + data->act[adr] = 0; // current = 0 + data->ctrl[0] = V; + mj_forward(model, data); + + // i_dot should be clipped to +dimax + EXPECT_NEAR(data->act_dot[adr], dimax, MjTol(1e-12, 1e-5)); + + // reverse: large negative drive + data->ctrl[0] = -V; + mj_forward(model, data); + + // i_dot should be clipped to -dimax + EXPECT_NEAR(data->act_dot[adr], -dimax, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, LuGreExactIntegration) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + ASSERT_EQ(model->actuator_actnum[0], 1); + int adr = model->actuator_actadr[0]; + + double sigma0 = 100, F_C = 0.5, F_S = 0.7, v_S = 10; + double z0 = 0.002; + double v = 0.5; + double h = model->opt.timestep; + + data->act[adr] = z0; + data->qvel[0] = v; + + double ratio = v / v_S; + double g_v = F_C + (F_S - F_C) * mju_exp(-ratio*ratio); + double a = -sigma0 * std::abs(v) / g_v; + double exp_ah = mju_exp(a * h); + double int_h = (exp_ah - 1) / a; + double z_new = exp_ah * z0 + int_h * v; + + mj_step(model, data); + EXPECT_NEAR(data->act[adr], z_new, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, LuGreSteadyState) { + static constexpr char xml[] = R"( + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + int adr = model->actuator_actadr[0]; + + double sigma0 = 100, sigma2 = 0.01; + double F_C = 0.5, F_S = 0.7, v_S = 10; + double K = 0.05, R = 2.0; + double v = 0.5; + + data->qvel[0] = v; + data->ctrl[0] = 0; + for (int i = 0; i < 10000; i++) { + mj_step(model, data); + } + + double ratio = v / v_S; + double g_v = F_C + (F_S - F_C) * mju_exp(-ratio*ratio); + double z_ss = g_v / sigma0; + EXPECT_NEAR(data->act[adr], z_ss, 1e-4); + + EXPECT_MJTNUM_EQ(model->actuator_damping[0], sigma2); + double back_emf = K * K / R * data->qvel[0]; + double lugre_ss = g_v; + EXPECT_NEAR(data->actuator_force[0], -back_emf - lugre_ss, 1e-3); + + mj_deleteData(data); + mj_deleteModel(model); +} + +TEST_F(DCMotorTest, LuGreBristleSpring) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + char error[1024]; + mjModel* model = LoadModelFromString(xml, error, sizeof(error)); + ASSERT_THAT(model, NotNull()) << error; + mjData* data = mj_makeData(model); + + int adr = model->actuator_actadr[0]; + double sigma0 = 100; + double X = 0.01; + + data->act[adr] = X; + data->ctrl[0] = 0; + mj_forward(model, data); + + EXPECT_NEAR(data->actuator_force[0], -sigma0 * X, MjTol(1e-12, 1e-5)); + + mj_deleteData(data); + mj_deleteModel(model); +} + // ----------------------- filterexact actuators ------------------------------- using FilterExactTest = MujocoTest; diff --git a/test/engine/testdata/derivative/dcmotor.xml b/test/engine/testdata/derivative/dcmotor.xml new file mode 100644 index 0000000000..d3c4b0acad --- /dev/null +++ b/test/engine/testdata/derivative/dcmotor.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/user/CMakeLists.txt b/test/user/CMakeLists.txt index 7cbc667f2c..912a26c0d4 100644 --- a/test/user/CMakeLists.txt +++ b/test/user/CMakeLists.txt @@ -14,9 +14,6 @@ mujoco_test( user_model_test - PROPERTIES - ENVIRONMENT - "MUJOCO_PLUGIN_DIR=$" ADDITIONAL_LINK_LIBRARIES absl::str_format ) @@ -31,16 +28,10 @@ mujoco_test( mujoco_test( user_flex_test - PROPERTIES - ENVIRONMENT - "MUJOCO_PLUGIN_DIR=$" ) mujoco_test( user_mesh_test - PROPERTIES - ENVIRONMENT - "MUJOCO_PLUGIN_DIR=$" ADDITIONAL_LINK_LIBRARIES absl::str_format ) diff --git a/test/xml/CMakeLists.txt b/test/xml/CMakeLists.txt index 63348c625b..d4cefdb1b3 100644 --- a/test/xml/CMakeLists.txt +++ b/test/xml/CMakeLists.txt @@ -16,9 +16,6 @@ mujoco_test(xml_api_test) mujoco_test( xml_native_reader_test - PROPERTIES - ENVIRONMENT - "MUJOCO_PLUGIN_DIR=$" ) mujoco_test(xml_utils_test) diff --git a/test/xml/xml_native_reader_test.cc b/test/xml/xml_native_reader_test.cc index 45ae2393de..bd1b517bc3 100644 --- a/test/xml/xml_native_reader_test.cc +++ b/test/xml/xml_native_reader_test.cc @@ -3003,6 +3003,325 @@ TEST_F(ActuatorParseTest, AdhesionInheritsFromGeneral) { mj_deleteModel(model); } +TEST_F(ActuatorParseTest, DCMotorBasicParsing) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + EXPECT_EQ(model->actuator_dyntype[0], mjDYN_DCMOTOR); + EXPECT_EQ(model->actuator_gaintype[0], mjGAIN_DCMOTOR); + EXPECT_EQ(model->actuator_biastype[0], mjBIAS_DCMOTOR); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[0], 2.0); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[1], 0.05); + EXPECT_MJTNUM_EQ(model->actuator_damping[0], 1.0); + EXPECT_MJTNUM_EQ(model->actuator_dampingpoly[0], 2.0); + EXPECT_MJTNUM_EQ(model->actuator_dampingpoly[1], 3.0); + EXPECT_MJTNUM_EQ(model->actuator_armature[0], 0.1); + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorNominalDerivation) { + static constexpr char xml[] = R"( + + + + + + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + + // actuator 0: B = 0, Ke = vn/omega0 + { + double K = 12.0 / 600.0; + double R = K * 12.0 / 0.6; + EXPECT_MJTNUM_EQ(model->actuator_gainprm[0*mjNGAIN + 0], R); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[0*mjNGAIN + 1], K); + } + + // actuator 1: B > 0, R given, quadratic Ke^2*omega0 - Ke*vn + R*B*omega0 = 0 + { + double B = 0.0001, R = 0.4, vn = 12.0, omega0 = 600.0; + double disc = vn*vn - 4*R*B*omega0*omega0; + double Ke = (vn + sqrt(disc)) / (2*omega0); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[1*mjNGAIN + 0], R); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[1*mjNGAIN + 1], Ke); + } + + // actuator 2: B > 0, R from nominal, Ke = vn/omega0 - vn*B/tau0 + { + double B = 0.0001, vn = 12.0, tau0 = 0.6, omega0 = 600.0; + double Ke = vn / omega0 - vn*B / tau0; + double R = Ke * vn / tau0; + EXPECT_MJTNUM_EQ(model->actuator_gainprm[2*mjNGAIN + 0], R); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[2*mjNGAIN + 1], Ke); + } + + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorSaturation) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + EXPECT_EQ(model->actuator_forcelimited[0], 1); + EXPECT_MJTNUM_EQ(model->actuator_forcerange[0], -1.5); + EXPECT_MJTNUM_EQ(model->actuator_forcerange[1], 1.5); + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorLuGreRemapping) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + EXPECT_MJTNUM_EQ(model->actuator_dynprm[5], 100); + EXPECT_MJTNUM_EQ(model->actuator_dynprm[6], 1); + EXPECT_MJTNUM_EQ(model->actuator_damping[0], 0.01); + EXPECT_MJTNUM_EQ(model->actuator_biasprm[3], 0.5); + EXPECT_MJTNUM_EQ(model->actuator_biasprm[4], 0.7); + EXPECT_MJTNUM_EQ(model->actuator_biasprm[5], 10); + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorActdimStateless) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + EXPECT_EQ(model->actuator_actnum[0], 0); + EXPECT_EQ(model->actuator_actadr[0], -1); + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorActdimCurrentOnly) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + EXPECT_EQ(model->actuator_actnum[0], 1); + EXPECT_MJTNUM_EQ(model->actuator_dynprm[0], 0.001 / 2.0); + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorActdimThermalOnly) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + EXPECT_EQ(model->actuator_actnum[0], 1); + EXPECT_MJTNUM_EQ(model->actuator_dynprm[2], 10); + EXPECT_MJTNUM_EQ(model->actuator_dynprm[3], 5); + EXPECT_MJTNUM_EQ(model->actuator_dynprm[4], 25); + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorActdimLuGreOnly) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + EXPECT_EQ(model->actuator_actnum[0], 1); + EXPECT_MJTNUM_EQ(model->actuator_dynprm[5], 100); + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorActdimAllThree) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + EXPECT_EQ(model->actuator_actnum[0], 3); + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorMissingKError) { + static constexpr char xml[] = R"( + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, IsNull()); + EXPECT_THAT(error.data(), HasSubstr("motor constant K must be positive")); +} + +TEST_F(ActuatorParseTest, DCMotorDefaultsPropagate) { + static constexpr char xml[] = R"( + + + + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[0], 1.5); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[1], 0.03); + mj_deleteModel(model); +} + +TEST_F(ActuatorParseTest, DCMotorMotorconstGeometricMean) { + static constexpr char xml[] = R"( + + + + + + + + + + + + + )"; + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + ASSERT_THAT(model, NotNull()) << error.data(); + double K = std::sqrt(0.03 * 0.05); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[0], 2.0); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[1], K); + EXPECT_MJTNUM_EQ(model->actuator_gainprm[mjNGAIN + 1], 0.03); + mj_deleteModel(model); +} + TEST_F(ActuatorParseTest, ActdimDefaultsPropagate) { static constexpr char xml[] = R"( diff --git a/unity/Runtime/Bindings/MjBindings.cs b/unity/Runtime/Bindings/MjBindings.cs index ea8ae869c2..1e6168a9de 100644 --- a/unity/Runtime/Bindings/MjBindings.cs +++ b/unity/Runtime/Bindings/MjBindings.cs @@ -269,19 +269,22 @@ public enum mjtDyn : int{ mjDYN_FILTER = 2, mjDYN_FILTEREXACT = 3, mjDYN_MUSCLE = 4, - mjDYN_USER = 5, + mjDYN_DCMOTOR = 5, + mjDYN_USER = 6, } public enum mjtGain : int{ mjGAIN_FIXED = 0, mjGAIN_AFFINE = 1, mjGAIN_MUSCLE = 2, - mjGAIN_USER = 3, + mjGAIN_DCMOTOR = 3, + mjGAIN_USER = 4, } public enum mjtBias : int{ mjBIAS_NONE = 0, mjBIAS_AFFINE = 1, mjBIAS_MUSCLE = 2, - mjBIAS_USER = 3, + mjBIAS_DCMOTOR = 3, + mjBIAS_USER = 4, } public enum mjtObj : int{ mjOBJ_UNKNOWN = 0, diff --git a/wasm/CMakeLists.txt b/wasm/CMakeLists.txt index 909401d1e8..0800483955 100644 --- a/wasm/CMakeLists.txt +++ b/wasm/CMakeLists.txt @@ -67,6 +67,11 @@ set_target_properties(mujoco_wasm PROPERTIES OUTPUT_NAME "mujoco" ) -target_link_libraries(mujoco_wasm ccd lodepng mujoco tinyxml2 qhullstatic_r obj_decoder stl_decoder) +# Link the mujoco library as a whole archive to avoid losing plugin +# registration such as obj_decoder and stl_decoder. +target_link_libraries(mujoco_wasm PRIVATE + -Wl,--whole-archive mujoco -Wl,--no-whole-archive + ccd lodepng tinyxml2 qhullstatic_r +) install(TARGETS mujoco_wasm DESTINATION ${DIVISIBLE_INSTALL_BIN_DIR}) diff --git a/wasm/codegen/generated/bindings.cc b/wasm/codegen/generated/bindings.cc index 53013ccd7f..cfa83b4068 100644 --- a/wasm/codegen/generated/bindings.cc +++ b/wasm/codegen/generated/bindings.cc @@ -9876,6 +9876,18 @@ std::string mjs_setToCylinder_wrapper(MjsActuator& actuator, double timeconst, d return std::string(mjs_setToCylinder(actuator.get(), timeconst, bias, area, diameter)); } +std::string mjs_setToDCMotor_wrapper(MjsActuator& actuator, const val& motorconst, double resistance, const val& nominal, const val& saturation, const val& inductance, const val& cogging, const val& controller, const val& thermal, const val& lugre, int input_mode) { + UNPACK_VALUE(double, motorconst); + UNPACK_VALUE(double, nominal); + UNPACK_VALUE(double, saturation); + UNPACK_VALUE(double, inductance); + UNPACK_VALUE(double, cogging); + UNPACK_VALUE(double, controller); + UNPACK_VALUE(double, thermal); + UNPACK_VALUE(double, lugre); + return std::string(mjs_setToDCMotor(actuator.get(), motorconst_.data(), resistance, nominal_.data(), saturation_.data(), inductance_.data(), cogging_.data(), controller_.data(), thermal_.data(), lugre_.data(), input_mode)); +} + std::string mjs_setToDamper_wrapper(MjsActuator& actuator, double kv) { return std::string(mjs_setToDamper(actuator.get(), kv)); } @@ -10812,6 +10824,7 @@ EMSCRIPTEN_BINDINGS(mujoco_bindings) { .value("mjBIAS_NONE", mjBIAS_NONE) .value("mjBIAS_AFFINE", mjBIAS_AFFINE) .value("mjBIAS_MUSCLE", mjBIAS_MUSCLE) + .value("mjBIAS_DCMOTOR", mjBIAS_DCMOTOR) .value("mjBIAS_USER", mjBIAS_USER); enum_("mjtBuiltin") .value("mjBUILTIN_NONE", mjBUILTIN_NONE) @@ -10912,6 +10925,7 @@ EMSCRIPTEN_BINDINGS(mujoco_bindings) { .value("mjDYN_FILTER", mjDYN_FILTER) .value("mjDYN_FILTEREXACT", mjDYN_FILTEREXACT) .value("mjDYN_MUSCLE", mjDYN_MUSCLE) + .value("mjDYN_DCMOTOR", mjDYN_DCMOTOR) .value("mjDYN_USER", mjDYN_USER); enum_("mjtEnableBit") .value("mjENBL_OVERRIDE", mjENBL_OVERRIDE) @@ -10974,6 +10988,7 @@ EMSCRIPTEN_BINDINGS(mujoco_bindings) { .value("mjGAIN_FIXED", mjGAIN_FIXED) .value("mjGAIN_AFFINE", mjGAIN_AFFINE) .value("mjGAIN_MUSCLE", mjGAIN_MUSCLE) + .value("mjGAIN_DCMOTOR", mjGAIN_DCMOTOR) .value("mjGAIN_USER", mjGAIN_USER); enum_("mjtGeom") .value("mjGEOM_PLANE", mjGEOM_PLANE) @@ -13295,6 +13310,7 @@ EMSCRIPTEN_BINDINGS(mujoco_bindings) { function("mjs_setName", &mjs_setName_wrapper); function("mjs_setToAdhesion", &mjs_setToAdhesion_wrapper); function("mjs_setToCylinder", &mjs_setToCylinder_wrapper); + function("mjs_setToDCMotor", &mjs_setToDCMotor_wrapper); function("mjs_setToDamper", &mjs_setToDamper_wrapper); function("mjs_setToIntVelocity", &mjs_setToIntVelocity_wrapper); function("mjs_setToMotor", &mjs_setToMotor_wrapper); diff --git a/wasm/tests/CMakeLists.txt b/wasm/tests/CMakeLists.txt index 0a21ec31c8..f6d7428986 100644 --- a/wasm/tests/CMakeLists.txt +++ b/wasm/tests/CMakeLists.txt @@ -53,6 +53,11 @@ add_executable(mujoco_wasm_benchmark ${MUJOCO_WASM_FILES}) set_target_properties(mujoco_wasm_benchmark PROPERTIES LINK_FLAGS "${EMCC_LINKER_FLAGS_STR}") -target_link_libraries(mujoco_wasm_benchmark ccd lodepng mujoco tinyxml2 qhullstatic_r obj_decoder stl_decoder) +# Link the mujoco library as a whole archive to avoid losing plugin +# registration such as obj_decoder and stl_decoder. +target_link_libraries(mujoco_wasm_benchmark PRIVATE + -Wl,--whole-archive mujoco -Wl,--no-whole-archive + ccd lodepng tinyxml2 qhullstatic_r +) install(TARGETS mujoco_wasm_benchmark DESTINATION ${DIVISIBLE_INSTALL_BIN_DIR})