Merge pull request #522 from williamshen-nz/main

Fix backward in SelfCollisionDistance
This commit is contained in:
Balakumar Sundaralingam
2025-06-09 14:50:23 -07:00
committed by GitHub

View File

@@ -123,7 +123,7 @@ class SelfCollisionDistance(torch.autograd.Function):
if ctx.needs_input_grad[3]: if ctx.needs_input_grad[3]:
(g_vec,) = ctx.saved_tensors (g_vec,) = ctx.saved_tensors
if ctx.return_loss: if ctx.return_loss:
g_vec = g_vec * grad_out_distance.unsqueeze(1) g_vec = g_vec * grad_out_distance.view(*g_vec.shape[:2], 1, 1)
sphere_grad = g_vec sphere_grad = g_vec
return None, None, None, sphere_grad, None, None, None, None, None, None, None, None return None, None, None, sphere_grad, None, None, None, None, None, None, None, None