update pose inverse for goalset

This commit is contained in:
Balakumar Sundaralingam
2024-02-23 16:19:18 -08:00
parent f25281e930
commit 286b3820a5
3 changed files with 39 additions and 23 deletions

View File

@@ -941,7 +941,6 @@ class PoseInverse(torch.autograd.Function):
adj_position: torch.Tensor,
adj_quaternion: torch.Tensor,
):
b, _ = position.shape
if out_position is None:
out_position = torch.zeros_like(position)
@@ -951,7 +950,8 @@ class PoseInverse(torch.autograd.Function):
adj_position = torch.zeros_like(position)
if adj_quaternion is None:
adj_quaternion = torch.zeros_like(quaternion)
b, _ = position.view(-1, 3).shape
ctx.b = b
init_warp()
ctx.save_for_backward(
position,
@@ -961,7 +961,6 @@ class PoseInverse(torch.autograd.Function):
adj_position,
adj_quaternion,
)
ctx.b = b
wp.launch(
kernel=compute_pose_inverse,
@@ -976,9 +975,6 @@ class PoseInverse(torch.autograd.Function):
],
stream=wp.stream_from_torch(position.device),
)
# remove close to zero values:
# out_position[torch.abs(out_position)<1e-8] = 0.0
# out_quaternion[torch.abs(out_quaternion)<1e-8] = 0.0
return out_position, out_quaternion