diff --git a/src/curobo/curobolib/geom.py b/src/curobo/curobolib/geom.py index b237cd0..35da6ac 100644 --- a/src/curobo/curobolib/geom.py +++ b/src/curobo/curobolib/geom.py @@ -123,7 +123,7 @@ class SelfCollisionDistance(torch.autograd.Function): if ctx.needs_input_grad[3]: (g_vec,) = ctx.saved_tensors if ctx.return_loss: - g_vec = g_vec * grad_out_distance.unsqueeze(1) + g_vec = g_vec * grad_out_distance.view(*g_vec.shape[:2], 1, 1) sphere_grad = g_vec return None, None, None, sphere_grad, None, None, None, None, None, None, None, None