From a22a2fddf6c11afd1c212e3329bfa443d6410357 Mon Sep 17 00:00:00 2001 From: Yuheng Zhi Date: Thu, 26 Sep 2024 17:46:17 -0700 Subject: [PATCH] fix quaternion's gradients in PoseInverse, and a few other warp transform kernels having the same issue --- src/curobo/geom/transform.py | 36 +++++++++--------------------------- 1 file changed, 9 insertions(+), 27 deletions(-) diff --git a/src/curobo/geom/transform.py b/src/curobo/geom/transform.py index dacc2d2..2757e5c 100644 --- a/src/curobo/geom/transform.py +++ b/src/curobo/geom/transform.py @@ -427,11 +427,7 @@ def compute_pose_inverse( # write pt: out_q = wp.transform_get_rotation(t_3) - out_v = wp.vec4() - out_v[0] = out_q[3] # out_q[3] - out_v[1] = out_q[0] # [0] - out_v[2] = out_q[1] # wp.extract(out_q, 1) - out_v[3] = out_q[2] # wp.extract(out_q, 2) + out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2]) out_position[b_idx] = wp.transform_get_translation(t_3) out_quat[b_idx] = out_v @@ -453,11 +449,7 @@ def compute_matrix_to_quat( # create a transform from a vector/quaternion: out_q = wp.quat_from_matrix(in_m) - out_v = wp.vec4() - out_v[0] = out_q[3] # wp.extract(out_q, 3) - out_v[1] = out_q[0] # wp.extract(out_q, 0) - out_v[2] = out_q[1] # wp.extract(out_q, 1) - out_v[3] = out_q[2] # wp.extract(out_q, 2) + out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2]) # write pt: out_quat[b_idx] = out_v @@ -562,11 +554,7 @@ def compute_batch_pose_multipy( # write pt: out_q = wp.transform_get_rotation(t_3) - out_v = wp.vec4() - out_v[0] = out_q[3] - out_v[1] = out_q[0] - out_v[2] = out_q[1] - out_v[3] = out_q[2] + out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2]) out_position[b_idx] = wp.transform_get_translation(t_3) out_quat[b_idx] = out_v @@ -626,11 +614,7 @@ def compute_pose_multipy( # write pt: out_q = wp.transform_get_rotation(t_3) - out_v = wp.vec4() - out_v[0] = out_q[3] - out_v[1] = out_q[0] - out_v[2] = out_q[1] - out_v[3] = out_q[2] + out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2]) out_position[b_idx] = wp.transform_get_translation(t_3) out_quat[b_idx] = out_v @@ -850,7 +834,7 @@ class BatchTransformPose(torch.autograd.Function): adj_position2: torch.Tensor, adj_quaternion2: torch.Tensor, ): - b, _ = position.shape + b, _ = position.view(-1, 3).shape if out_position is None: out_position = torch.zeros_like(position2) @@ -977,7 +961,7 @@ class BatchTransformPose(torch.autograd.Function): g_p2 = adj_position2 if ctx.needs_input_grad[3]: g_q2 = adj_quaternion2 - return g_p1, g_q1, g_p2, g_q2, None, None, None, None + return g_p1, g_q1, g_p2, g_q2, None, None, None, None, None, None class TransformPose(torch.autograd.Function): @@ -997,7 +981,7 @@ class TransformPose(torch.autograd.Function): adj_position2: torch.Tensor, adj_quaternion2: torch.Tensor, ): - b, _ = position2.shape + b, _ = position2.view(-1, 3).shape init_warp() if out_position is None: out_position = torch.zeros_like(position2) @@ -1123,7 +1107,7 @@ class TransformPose(torch.autograd.Function): g_p2 = adj_position2 if ctx.needs_input_grad[3]: g_q2 = adj_quaternion2 - return g_p1, g_q1, g_p2, g_q2, None, None, None, None + return g_p1, g_q1, g_p2, g_q2, None, None, None, None, None, None class PoseInverse(torch.autograd.Function): @@ -1223,8 +1207,6 @@ class PoseInverse(torch.autograd.Function): adj_inputs=[ None, None, - None, - None, ], adj_outputs=[ None, @@ -1239,7 +1221,7 @@ class PoseInverse(torch.autograd.Function): if ctx.needs_input_grad[1]: g_q1 = adj_quaternion - return g_p1, g_q1, None, None + return g_p1, g_q1, None, None, None, None class QuatToMatrix(torch.autograd.Function):