Fix backward in SelfCollisionDistance
This commit is contained in:
@@ -123,7 +123,7 @@ class SelfCollisionDistance(torch.autograd.Function):
|
|||||||
if ctx.needs_input_grad[3]:
|
if ctx.needs_input_grad[3]:
|
||||||
(g_vec,) = ctx.saved_tensors
|
(g_vec,) = ctx.saved_tensors
|
||||||
if ctx.return_loss:
|
if ctx.return_loss:
|
||||||
g_vec = g_vec * grad_out_distance.unsqueeze(1)
|
g_vec = g_vec * grad_out_distance.view(*g_vec.shape[:2], 1, 1)
|
||||||
sphere_grad = g_vec
|
sphere_grad = g_vec
|
||||||
return None, None, None, sphere_grad, None, None, None, None, None, None, None, None
|
return None, None, None, sphere_grad, None, None, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user