Fix backward in SelfCollisionDistance

This commit is contained in:
William Shen
2024-12-20 13:49:38 +13:00
parent 2fbffc3522
commit e2ef88edb7

View File

@@ -123,7 +123,7 @@ class SelfCollisionDistance(torch.autograd.Function):
if ctx.needs_input_grad[3]: if ctx.needs_input_grad[3]:
(g_vec,) = ctx.saved_tensors (g_vec,) = ctx.saved_tensors
if ctx.return_loss: if ctx.return_loss:
g_vec = g_vec * grad_out_distance.unsqueeze(1) g_vec = g_vec * grad_out_distance.view(*g_vec.shape[:2], 1, 1)
sphere_grad = g_vec sphere_grad = g_vec
return None, None, None, sphere_grad, None, None, None, None, None, None, None, None return None, None, None, sphere_grad, None, None, None, None, None, None, None, None