diff --git a/mjx/mujoco/mjx/_src/collision_convex.py b/mjx/mujoco/mjx/_src/collision_convex.py index 0d60c4c973..6a0cc74367 100644 --- a/mjx/mujoco/mjx/_src/collision_convex.py +++ b/mjx/mujoco/mjx/_src/collision_convex.py @@ -153,7 +153,7 @@ def plane_convex(plane: GeomInfo, convex: ConvexInfo) -> Collision: frame = jp.stack([math.make_frame(n)] * 4, axis=0) unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1 - dist = jp.where(unique, -support[idx], 1) + dist = jp.where(unique, -support[idx], jp.abs(support[idx]).max()) pos = pos - 0.5 * dist[:, None] * n return dist, pos, frame @@ -214,7 +214,7 @@ def get_support(faces, normal): d *= sign spt = sphere_pos + n * sphere.size[0] - dist = jp.where(has_separating_axis, 1.0, d - sphere.size[0]) + dist = d - sphere.size[0] pos = (pt + spt) * 0.5 # Go back to world frame. @@ -280,9 +280,8 @@ def get_support(face, normal): # Create variables for the face contact. pos = (cap_pts_clipped + face_pts) * 0.5 contact_normal = -jp.stack([normal] * 2, 0) - face_penetration = jp.where( - mask & has_support, jp.dot(face_pts - cap_pts_clipped, normal), -1 - ) + face_dist = jp.dot(face_pts - cap_pts_clipped, normal) + face_penetration = jp.where(mask & has_support, face_dist, jp.minimum(face_dist, 0)) # Pick a potential shallow edge contact. def get_edge_axis(edge): @@ -316,7 +315,7 @@ def get_edge_axis(edge): edge_face_normals = edge_face_normal[e_idx] edge_voronoi_front = ((edge_face_normals @ edge_axis) < 0).all() shallow = ~degenerate_edge_dir & edge_voronoi_front - edge_penetration = jp.where(shallow, cap.size[0] - edge_dist, -1) + edge_penetration = cap.size[0] - edge_dist # Determine edge contact position. edge_pos = ( @@ -327,7 +326,8 @@ def get_edge_axis(edge): ) & ~degenerate_edge_dir min_face_penetration = face_penetration.min() has_edge_contact = ( - (edge_penetration > 0) + shallow + & (edge_penetration > 0) # prefer edge contact if the edge is smaller than face penetration & jp.where( min_face_penetration > 0, @@ -351,7 +351,9 @@ def get_edge_axis(edge): n = n @ convex.mat.T dist = -jp.where( - has_edge_contact, jp.array([edge_penetration, -1]), face_penetration + has_edge_contact, + jp.array([edge_penetration, -edge_penetration]), + face_penetration, ) return dist, pos, n @@ -577,7 +579,7 @@ def _create_contact_manifold( penetration_dir = jp.take(poly_incident, best, axis=0) - contact_pts penetration = penetration_dir.dot(-clipping_norm) - dist = jp.where(mask_pts, -penetration, jp.ones_like(penetration)) + dist = jp.where(mask_pts, -penetration, jp.abs(penetration)) pos = contact_pts normal = -jp.stack([sep_axis] * 4, 0) return dist, pos, normal @@ -677,7 +679,7 @@ def get_support(axis, is_degenerate): idx = dist.argmin() dist = jp.where( is_edge_contact, - jp.array([dist[idx], 1, 1, 1]), + jp.array([dist[idx], jp.abs(dist[idx]), jp.abs(dist[idx]), jp.abs(dist[idx])]), dist, ) pos = jp.where(is_edge_contact, jp.tile(pos[idx], (4, 1)), pos) @@ -806,7 +808,7 @@ def get_support(axis): incident_face_norm, -best_axis, ) - dist = jp.where(is_face_separating, 1.0, dist) + dist = jp.where(is_face_separating, jp.abs(dist), dist) # Handle edge separating axes by checking all edge pairs. a_idx = jp.tile(jp.arange(edges_a.shape[0]), reps=edges_b.shape[0]) @@ -856,7 +858,7 @@ def get_normals(a_dir, a_pt, b_dir): normal = jp.where(is_edge_contact, edge_axes[best_edge_idx], normal) dist = jp.where( is_edge_contact, - jp.array([best_edge_dist, 1, 1, 1]), + jp.array([best_edge_dist, jp.abs(best_edge_dist), jp.abs(best_edge_dist), jp.abs(best_edge_dist)]), dist, ) a_closest, b_closest = math.closest_segment_to_segment_points( @@ -1060,7 +1062,7 @@ def hfield_sphere( # zero out non-unique contacts unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1 - dist = jp.where(unique, dist, 1) + dist = jp.where(unique, dist, jp.max(dist)) # back to world frame, _hfield_collision returns collision in hfield frame pos = jax.vmap(lambda p: h.mat @ p + h.pos)(pos) @@ -1084,7 +1086,7 @@ def hfield_capsule( # zero out non-unique contacts unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1 - dist = jp.where(unique, dist, 1) + dist = jp.where(unique, dist, jp.max(dist)) # back to world frame, _hfield_collision returns collision in hfield frame pos = jax.vmap(lambda p: h.mat @ p + h.pos)(pos) @@ -1108,7 +1110,7 @@ def hfield_convex( # zero out non-unique contacts unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1 - dist = jp.where(unique, dist, 1) + dist = jp.where(unique, dist, jp.max(dist)) # back to world frame, _hfield_collision returns collision in hfield frame pos = jax.vmap(lambda p: h.mat @ p + h.pos)(pos)