
MJX: add differentiable distance calculation #3138

Open
aarushk09 wants to merge 1 commit into google-deepmind:main from aarushk09:mjx-differentiable-distance

@aarushk09

Summary

Replace hard-coded dist=1 sentinel values in MJX collision functions with actual signed distance values. This preserves the gradient chain through JAX even when geoms are not colliding, enabling differentiable distance computation.

Changes

All changes are in mjx/mujoco/mjx/_src/collision_convex.py (17 insertions, 15 deletions):

  • _sphere_convex: Remove jp.where(has_separating_axis, 1.0, ...) — the expression d - sphere.size[0] is already correct for both cases
  • _capsule_convex: Use actual face/edge distances instead of -1 sentinels
  • _create_contact_manifold: Use jp.abs(penetration) for non-contact points instead of jp.ones_like
  • _box_box_impl / _sat_gaussmap: Use jp.abs(dist) for inactive edge contact slots and face-separating cases
  • plane_convex / hfield_*: Use jp.abs(support).max() / jp.max(dist) for non-unique contacts
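
The effect of removing a sentinel can be illustrated with a small JAX sketch. This is not the MJX source: `dist_with_sentinel` and `dist_signed` are hypothetical toy functions that reduce the `_sphere_convex` case to a scalar center distance minus a radius, and only mirror the `jp.where(has_separating_axis, 1.0, ...)` pattern described above.

```python
import jax
import jax.numpy as jp


def dist_with_sentinel(center_dist, radius):
    # Hypothetical stand-in for the old behavior: when the geoms are
    # separated, a constant 1.0 is returned instead of the real distance.
    d = center_dist - radius
    has_separating_axis = d > 0
    return jp.where(has_separating_axis, 1.0, d)


def dist_signed(center_dist, radius):
    # Hypothetical stand-in for the new behavior: the signed distance is
    # returned in both the separated and the penetrating case.
    return center_dist - radius


# Separated configuration: center distance 2.0, radius 0.5.
grad_sentinel = jax.grad(dist_with_sentinel)(2.0, 0.5)
grad_signed = jax.grad(dist_signed)(2.0, 0.5)

print("gradient with sentinel:", grad_sentinel)
print("gradient with signed distance:", grad_signed)
```

In the separated case the sentinel branch is a constant, so the derivative with respect to the center distance vanishes, while the signed-distance version keeps the dependence on the input.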

Testing

  • All 37 existing collision_driver_test.py tests pass
  • Verified that jax.grad produces finite, non-zero gradients through contact.dist for non-colliding sphere-box pair
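
A check of this kind could look roughly like the sketch below. It does not call MJX: `sphere_box_dist` is a hypothetical closed-form signed distance between a sphere and an axis-aligned box at the origin, used only to show the shape of a "finite, non-zero gradient" assertion for a non-colliding pair.

```python
import jax
import jax.numpy as jp


def sphere_box_dist(sphere_pos, radius, box_half):
    # Toy signed distance from a sphere to an axis-aligned box centered at
    # the origin (an assumption for illustration, not the MJX collision code).
    q = jp.abs(sphere_pos) - box_half
    outside = jp.linalg.norm(jp.maximum(q, 0.0))  # distance when outside the box
    inside = jp.minimum(jp.max(q), 0.0)           # negative depth when inside
    return outside + inside - radius


sphere_pos = jp.array([2.0, 0.3, 0.1])  # sphere center, well clear of the box
box_half = jp.array([0.5, 0.5, 0.5])    # box half-extents
radius = 0.25

dist = sphere_box_dist(sphere_pos, radius, box_half)
grad = jax.grad(sphere_box_dist)(sphere_pos, radius, box_half)

# The pair is not colliding, yet the gradient should be finite and non-zero.
assert dist > 0
assert bool(jp.all(jp.isfinite(grad)))
assert bool(jp.any(grad != 0))
```

Note that this toy formula is only differentiated here at a point outside the box; at points where the `outside` term is exactly zero, the gradient of the norm would need separate handling.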

Fixes #3131

@google-cla

google-cla bot commented Mar 1, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Replace hard-coded dist=1 sentinel values in collision functions with
actual signed distance values. This preserves the gradient chain through
JAX even when geoms are not colliding, enabling differentiable distance
computation for optimization and learning applications.

Fixes google-deepmind#3131

@ZC502

ZC502 commented Mar 31, 2026

@aarushk09 @egordon,
Implementing differentiable distance in MJX (PR #3138) is a significant milestone. However, as we move from analytic primitives to mesh-to-mesh gradients, there's a hidden bottleneck that often degrades RL agents and trajectory optimizers: Non-Associative Residuals (NAR).

The "Numerical-Only" Differentiability Issue:
Even when JAX provides a valid gradient for dist = d - size, we’ve observed that at high-frequency interaction scales, these gradients can be physically inconsistent.

In our audits using SIPA (Simulation Integrity & Physics Auditor), developed based on the Non-Associative Residual Hypothesis (NARH), we found that standard collision pipelines (like GJK/EPA adaptations) can suffer from "Algebraic Leakage" when resolving redundant contact manifolds. This manifests as "Gradient Noise," which prevents optimizers from converging to physically consistent solutions.

The Discrete Associator $A(a,b,c;s) = ((\Psi_a \circ \Psi_b) \circ \Psi_c)(s) - (\Psi_a \circ (\Psi_b \circ \Psi_c))(s)$ serves as a "Physical Consistency Probe."

Observations from 7-DoF Benchmarks (e.g., KUKA iiwa):
When applying similar differentiable distance logic to high-precision industrial models, we noticed that the Associator—a measure of non-associative error—spikes precisely at the moment of the has_separating_axis toggle. This suggests that "differentiability" alone isn't enough for Sim-to-Real transfer; we need Algebraic Integrity.

[Image: z_jitter.png] Attached is a SIPA audit of a similar 7-DoF trajectory. Note the high-frequency Z-jitter exceeding 10 mm at the exact moments of constraint activation.

A Potential Path Forward:
To achieve true Sim-to-Real fidelity, we suggest evaluating not just the magnitude of these gradients, but their Physical Consistency. Our PoC indicates that mapping distance calculations onto non-associative (Octonion-based) manifolds can eliminate these "numerical toxins" by naturally coupling temporal and spatial states, leading to smoother, more transferable gradients.

Is the MJX team interested in using SIPA to audit the "cleanliness" of the gradients produced by PR #3138? This might reveal why policies trained on these distances sometimes struggle with the reality gap.

Specifically, for PR #3138, how does the team plan to handle gradient discontinuities during the transition from non-penetration to contact in mesh-to-mesh scenarios? We suspect NARH analysis could provide a quantitative metric for this stability.

Ref: https://github.com/ZC502/SIPA.git
