fix quaternion's gradients in PoseInverse, and a few other warp transform kernels having the same issue
This commit is contained in:
@@ -427,11 +427,7 @@ def compute_pose_inverse(
|
|||||||
# write pt:
|
# write pt:
|
||||||
out_q = wp.transform_get_rotation(t_3)
|
out_q = wp.transform_get_rotation(t_3)
|
||||||
|
|
||||||
out_v = wp.vec4()
|
out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2])
|
||||||
out_v[0] = out_q[3] # out_q[3]
|
|
||||||
out_v[1] = out_q[0] # [0]
|
|
||||||
out_v[2] = out_q[1] # wp.extract(out_q, 1)
|
|
||||||
out_v[3] = out_q[2] # wp.extract(out_q, 2)
|
|
||||||
|
|
||||||
out_position[b_idx] = wp.transform_get_translation(t_3)
|
out_position[b_idx] = wp.transform_get_translation(t_3)
|
||||||
out_quat[b_idx] = out_v
|
out_quat[b_idx] = out_v
|
||||||
@@ -453,11 +449,7 @@ def compute_matrix_to_quat(
|
|||||||
# create a transform from a vector/quaternion:
|
# create a transform from a vector/quaternion:
|
||||||
out_q = wp.quat_from_matrix(in_m)
|
out_q = wp.quat_from_matrix(in_m)
|
||||||
|
|
||||||
out_v = wp.vec4()
|
out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2])
|
||||||
out_v[0] = out_q[3] # wp.extract(out_q, 3)
|
|
||||||
out_v[1] = out_q[0] # wp.extract(out_q, 0)
|
|
||||||
out_v[2] = out_q[1] # wp.extract(out_q, 1)
|
|
||||||
out_v[3] = out_q[2] # wp.extract(out_q, 2)
|
|
||||||
# write pt:
|
# write pt:
|
||||||
out_quat[b_idx] = out_v
|
out_quat[b_idx] = out_v
|
||||||
|
|
||||||
@@ -562,11 +554,7 @@ def compute_batch_pose_multipy(
|
|||||||
# write pt:
|
# write pt:
|
||||||
out_q = wp.transform_get_rotation(t_3)
|
out_q = wp.transform_get_rotation(t_3)
|
||||||
|
|
||||||
out_v = wp.vec4()
|
out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2])
|
||||||
out_v[0] = out_q[3]
|
|
||||||
out_v[1] = out_q[0]
|
|
||||||
out_v[2] = out_q[1]
|
|
||||||
out_v[3] = out_q[2]
|
|
||||||
|
|
||||||
out_position[b_idx] = wp.transform_get_translation(t_3)
|
out_position[b_idx] = wp.transform_get_translation(t_3)
|
||||||
out_quat[b_idx] = out_v
|
out_quat[b_idx] = out_v
|
||||||
@@ -626,11 +614,7 @@ def compute_pose_multipy(
|
|||||||
# write pt:
|
# write pt:
|
||||||
out_q = wp.transform_get_rotation(t_3)
|
out_q = wp.transform_get_rotation(t_3)
|
||||||
|
|
||||||
out_v = wp.vec4()
|
out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2])
|
||||||
out_v[0] = out_q[3]
|
|
||||||
out_v[1] = out_q[0]
|
|
||||||
out_v[2] = out_q[1]
|
|
||||||
out_v[3] = out_q[2]
|
|
||||||
|
|
||||||
out_position[b_idx] = wp.transform_get_translation(t_3)
|
out_position[b_idx] = wp.transform_get_translation(t_3)
|
||||||
out_quat[b_idx] = out_v
|
out_quat[b_idx] = out_v
|
||||||
@@ -850,7 +834,7 @@ class BatchTransformPose(torch.autograd.Function):
|
|||||||
adj_position2: torch.Tensor,
|
adj_position2: torch.Tensor,
|
||||||
adj_quaternion2: torch.Tensor,
|
adj_quaternion2: torch.Tensor,
|
||||||
):
|
):
|
||||||
b, _ = position.shape
|
b, _ = position.view(-1, 3).shape
|
||||||
|
|
||||||
if out_position is None:
|
if out_position is None:
|
||||||
out_position = torch.zeros_like(position2)
|
out_position = torch.zeros_like(position2)
|
||||||
@@ -977,7 +961,7 @@ class BatchTransformPose(torch.autograd.Function):
|
|||||||
g_p2 = adj_position2
|
g_p2 = adj_position2
|
||||||
if ctx.needs_input_grad[3]:
|
if ctx.needs_input_grad[3]:
|
||||||
g_q2 = adj_quaternion2
|
g_q2 = adj_quaternion2
|
||||||
return g_p1, g_q1, g_p2, g_q2, None, None, None, None
|
return g_p1, g_q1, g_p2, g_q2, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class TransformPose(torch.autograd.Function):
|
class TransformPose(torch.autograd.Function):
|
||||||
@@ -997,7 +981,7 @@ class TransformPose(torch.autograd.Function):
|
|||||||
adj_position2: torch.Tensor,
|
adj_position2: torch.Tensor,
|
||||||
adj_quaternion2: torch.Tensor,
|
adj_quaternion2: torch.Tensor,
|
||||||
):
|
):
|
||||||
b, _ = position2.shape
|
b, _ = position2.view(-1, 3).shape
|
||||||
init_warp()
|
init_warp()
|
||||||
if out_position is None:
|
if out_position is None:
|
||||||
out_position = torch.zeros_like(position2)
|
out_position = torch.zeros_like(position2)
|
||||||
@@ -1123,7 +1107,7 @@ class TransformPose(torch.autograd.Function):
|
|||||||
g_p2 = adj_position2
|
g_p2 = adj_position2
|
||||||
if ctx.needs_input_grad[3]:
|
if ctx.needs_input_grad[3]:
|
||||||
g_q2 = adj_quaternion2
|
g_q2 = adj_quaternion2
|
||||||
return g_p1, g_q1, g_p2, g_q2, None, None, None, None
|
return g_p1, g_q1, g_p2, g_q2, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class PoseInverse(torch.autograd.Function):
|
class PoseInverse(torch.autograd.Function):
|
||||||
@@ -1223,8 +1207,6 @@ class PoseInverse(torch.autograd.Function):
|
|||||||
adj_inputs=[
|
adj_inputs=[
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
None,
|
|
||||||
],
|
],
|
||||||
adj_outputs=[
|
adj_outputs=[
|
||||||
None,
|
None,
|
||||||
@@ -1239,7 +1221,7 @@ class PoseInverse(torch.autograd.Function):
|
|||||||
if ctx.needs_input_grad[1]:
|
if ctx.needs_input_grad[1]:
|
||||||
g_q1 = adj_quaternion
|
g_q1 = adj_quaternion
|
||||||
|
|
||||||
return g_p1, g_q1, None, None
|
return g_p1, g_q1, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class QuatToMatrix(torch.autograd.Function):
|
class QuatToMatrix(torch.autograd.Function):
|
||||||
|
|||||||
Reference in New Issue
Block a user