MJX: add differentiable distance calculation#3138
MJX: add differentiable distance calculation#3138aarushk09 wants to merge 1 commit intogoogle-deepmind:mainfrom
Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
Replace hard-coded dist=1 sentinel values in collision functions with actual signed distance values. This preserves the gradient chain through JAX even when geoms are not colliding, enabling differentiable distance computation for optimization and learning applications. Fixes google-deepmind#3131
|
@aarushk09 @egordon , The "Numerical-Only" Differentiability Issue: In our audits using SIPA (Simulation Integrity & Physics Auditor)—developed based on the Non-Associative Residual Hypothesis (NARH)—we found that standard collision pipelines (like GJK/EPA adaptations) can suffer from "Algebraic Leakage" when resolving redundant contact manifolds. This manifests as "Gradient Noise," which prevents optimizers from converging to physically consistent solutions.The Discrete Associator $A(a,b,c;s) = ((\Psi_a \circ \Psi_b) \circ \Psi_c)(s) - (\Psi_a \circ (\Psi_b \circ \Psi_c))(s)$ serves as a "Physical Consistency Probe." Observations from 7-DoF Benchmarks (e.g., KUKA iiwa): A Potential Path Forward: Is the MJX team interested in using SIPA to audit the "cleanliness" of the gradients produced by PR #3138? This might reveal why policies trained on these distances sometimes struggle with the reality gap.Specifically, for PR #3138, how does the team plan to handle gradient discontinuities during the transition from non-penetration to contact in mesh-to-mesh scenarios? We suspect NARH analysis could provide a quantitative metric for this stability. |

Summary
Replace hard-coded dist=1 sentinel values in MJX collision functions with actual signed distance values. This preserves the gradient chain through JAX even when geoms are not colliding, enabling differentiable distance computation.
Changes
All changes are in mjx/mujoco/mjx/_src/collision_convex.py (17 insertions, 15 deletions):
Testing
Fixes #3131