Files
gen_data_curobo/src/curobo/geom/transform.py
Balakumar Sundaralingam 07e6ccfc91 release repository
2023-10-26 04:17:19 -07:00

1210 lines
37 KiB
Python

#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from typing import Optional
# Third Party
import torch
import warp as wp
# CuRobo
from curobo.curobolib.kinematics import rotation_matrix_to_quaternion
from curobo.util.logger import log_error
from curobo.util.warp import init_warp
def transform_points(
    position, quaternion, points, out_points=None, out_gp=None, out_gq=None, out_gpt=None
):
    """Transform a set of points through the ``TransformPoint`` autograd function.

    Any buffer left as ``None`` is allocated here as a zero tensor.

    Args:
        position: Pose translation; its first dimension sizes ``out_gp``.
        quaternion: Pose rotation; its first dimension sizes ``out_gq``.
        points: Points to transform; its first dimension sizes the output.
        out_points: Optional output buffer of shape ``(points.shape[0], 3)``.
        out_gp: Optional gradient buffer for ``position``.
        out_gq: Optional gradient buffer for ``quaternion``.
        out_gpt: Optional gradient buffer for ``points``.

    Returns:
        The value returned by ``TransformPoint.apply``.
    """
    num_points = points.shape[0]
    if out_points is None:
        out_points = torch.zeros(num_points, 3, dtype=points.dtype, device=points.device)
    if out_gp is None:
        out_gp = torch.zeros(position.shape[0], 3, device=position.device)
    if out_gq is None:
        out_gq = torch.zeros(quaternion.shape[0], 4, device=quaternion.device)
    if out_gpt is None:
        # Note: allocated on the device of `position`, not `points`.
        out_gpt = torch.zeros(num_points, 3, device=position.device)
    return TransformPoint.apply(
        position, quaternion, points, out_points, out_gp, out_gq, out_gpt
    )
def batch_transform_points(
    position, quaternion, points, out_points=None, out_gp=None, out_gq=None, out_gpt=None
):
    """Transform batched points through the ``BatchTransformPoint`` autograd function.

    Any buffer left as ``None`` is allocated here as a zero tensor.

    Args:
        position: Pose translations; its first dimension sizes ``out_gp``.
        quaternion: Pose rotations; its first dimension sizes ``out_gq``.
        points: Points to transform; its first two dimensions size the output.
        out_points: Optional output buffer of shape
            ``(points.shape[0], points.shape[1], 3)``.
        out_gp: Optional gradient buffer for ``position``.
        out_gq: Optional gradient buffer for ``quaternion``.
        out_gpt: Optional gradient buffer for ``points``.

    Returns:
        The value returned by ``BatchTransformPoint.apply``.
    """
    batch, num_points = points.shape[0], points.shape[1]
    if out_points is None:
        out_points = torch.zeros(
            batch, num_points, 3, dtype=points.dtype, device=points.device
        )
    if out_gp is None:
        out_gp = torch.zeros(position.shape[0], 3, device=position.device)
    if out_gq is None:
        out_gq = torch.zeros(quaternion.shape[0], 4, device=quaternion.device)
    if out_gpt is None:
        # Note: allocated on the device of `position`, not `points`.
        out_gpt = torch.zeros(batch, num_points, 3, device=position.device)
    return BatchTransformPoint.apply(
        position, quaternion, points, out_points, out_gp, out_gq, out_gpt
    )
@torch.jit.script
def get_inv_transform(w_rot_c, w_trans_c):
    # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
    """Invert a rigid transform given as a rotation matrix and a translation.

    The inverse rotation is the transpose of the last two dimensions of
    ``w_rot_c``; the inverse translation is the negated product of that
    transpose with ``w_trans_c``.

    Returns:
        Tuple ``(c_rot_w, c_trans_w)``.
    """
    c_rot_w = torch.transpose(w_rot_c, -1, -2)
    # Treat the translation as a column vector for the matrix product.
    rotated_trans = torch.matmul(c_rot_w, w_trans_c.unsqueeze(-1)).squeeze(-1)
    c_trans_w = -1.0 * rotated_trans
    return c_rot_w, c_trans_w
@torch.jit.script
def transform_point_inverse(point, rot, trans):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """Apply the inverse of the transform ``(rot, trans)`` to ``point``.

    The inverse transform is obtained from ``get_inv_transform``; points are
    multiplied on the left of the transposed inverse rotation (row-vector
    convention) and the inverse translation is added.
    """
    inv_rot, inv_trans = get_inv_transform(rot, trans)
    rotated = torch.matmul(point, inv_rot.transpose(-1, -2))
    return rotated + inv_trans
def matrix_to_quaternion(matrix, out_quat=None, adj_matrix=None):
    """Convert rotation matrices to quaternions via ``MatrixToQuaternion``.

    Args:
        matrix: Rotation matrices; viewed as ``(-1, 3, 3)`` before conversion.
        out_quat: Optional output buffer forwarded to the autograd function.
        adj_matrix: Optional gradient buffer forwarded to the autograd function.

    Returns:
        The value returned by ``MatrixToQuaternion.apply``.
    """
    flat_matrix = matrix.view(-1, 3, 3)
    return MatrixToQuaternion.apply(flat_matrix, out_quat, adj_matrix)
def cuda_matrix_to_quaternion(matrix):
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4). [qw, qx,qy,qz]

    Raises:
        ValueError: If the last two dimensions of ``matrix`` are not both 3.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    # Flatten all leading batch dimensions so the output buffer and the input
    # handed to the conversion routine agree on the number of matrices.
    batch_shape = list(matrix.shape[:-2])
    mat_in = matrix.view(-1, 3, 3)
    out_quat = torch.zeros((mat_in.shape[0], 4), device=matrix.device, dtype=matrix.dtype)
    out_quat = rotation_matrix_to_quaternion(mat_in, out_quat)
    # Restore the caller's batch shape.
    return out_quat.view(batch_shape + [4])
def quaternion_to_matrix(quaternions, out_mat=None, adj_quaternion=None):
    """Convert quaternions to rotation matrices via ``QuatToMatrix``.

    Args:
        quaternions: Quaternions forwarded unchanged to the autograd function.
        out_mat: Optional output buffer.
        adj_quaternion: Optional gradient buffer.

    Returns:
        The value returned by ``QuatToMatrix.apply``.
    """
    return QuatToMatrix.apply(quaternions, out_mat, adj_quaternion)
def torch_quaternion_to_matrix(quaternions):
    """
    Build rotation matrices from quaternions using plain torch operations.

    The quaternion does not need to be unit length: every term is scaled by
    ``2 / sum(q * q)``.

    Args:
        quaternions: quaternions with real part first,
            as tensor (or tensor-like) of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    quat = torch.as_tensor(quaternions)
    r, i, j, k = quat.unbind(-1)
    two_s = 2.0 / torch.sum(quat * quat, dim=-1)

    # Matrix entries in row-major order.
    entries = (
        1 - two_s * (j * j + k * k),
        two_s * (i * j - k * r),
        two_s * (i * k + j * r),
        two_s * (i * j + k * r),
        1 - two_s * (i * i + k * k),
        two_s * (j * k - i * r),
        two_s * (i * k - j * r),
        two_s * (j * k + i * r),
        1 - two_s * (i * i + j * j),
    )
    flat = torch.stack(entries, dim=-1)
    return flat.reshape(quat.shape[:-1] + (3, 3))
def pose_to_matrix(
    position: torch.Tensor, quaternion: torch.Tensor, out_matrix: Optional[torch.Tensor] = None
):
    """Assemble homogeneous 4x4 matrices from positions and quaternions.

    Args:
        position: Translations. A 2-D tensor yields an output of shape
            ``(position.shape[0], 4, 4)``; otherwise the output is
            ``(position.shape[0], position.shape[1], 4, 4)``.
        quaternion: Rotations with 4 components in the last dimension.
        out_matrix: Optional buffer to fill; allocated as zeros when ``None``.

    Returns:
        ``out_matrix`` with the rotation in ``[..., :3, :3]``, the translation
        in ``[..., :3, 3]`` and ``1.0`` in ``[..., 3, 3]``.
    """
    if out_matrix is None:
        if len(position.shape) == 2:
            out_matrix = torch.zeros(
                (position.shape[0], 4, 4), device=position.device, dtype=position.dtype
            )
        else:
            out_matrix = torch.zeros(
                (position.shape[0], position.shape[1], 4, 4),
                device=position.device,
                dtype=position.dtype,
            )
    out_matrix[..., 3, 3] = 1.0
    out_matrix[..., :3, 3] = position
    # QuatToMatrix unpacks a 2-D shape, so flatten any leading batch
    # dimensions before the conversion and restore them afterwards.
    rotation = quaternion_to_matrix(quaternion.reshape(-1, 4))
    out_matrix[..., :3, :3] = rotation.reshape(quaternion.shape[:-1] + (3, 3))
    return out_matrix
def pose_multiply(
    position,
    quaternion,
    position2,
    quaternion2,
    out_position=None,
    out_quaternion=None,
    adj_pos=None,
    adj_quat=None,
    adj_pos2=None,
    adj_quat2=None,
):
    """Multiply two poses, dispatching on the shapes of the position tensors.

    * Same shape for ``position`` and ``position2``: ``BatchTransformPose``.
    * ``position`` has one row and ``position2`` more than one:
      ``TransformPose``.
    * Anything else: ``log_error`` is called with "shapes not supported".

    Returns:
        Tuple ``(out_position, out_quaternion)``.
    """
    apply_args = (
        position,
        quaternion,
        position2,
        quaternion2,
        out_position,
        out_quaternion,
        adj_pos,
        adj_quat,
        adj_pos2,
        adj_quat2,
    )
    same_shape = position.shape == position2.shape
    if same_shape:
        out_position, out_quaternion = BatchTransformPose.apply(*apply_args)
    elif position.shape[0] == 1 and position2.shape[0] > 1:
        out_position, out_quaternion = TransformPose.apply(*apply_args)
    else:
        log_error("shapes not supported")
    return out_position, out_quaternion
def pose_inverse(
    position,
    quaternion,
    out_position=None,
    out_quaternion=None,
    adj_pos=None,
    adj_quat=None,
):
    """Invert poses through the ``PoseInverse`` autograd function.

    All optional buffers are forwarded as given; ``PoseInverse`` allocates
    the ones that are ``None``.

    Returns:
        Tuple ``(out_position, out_quaternion)``.
    """
    inv_position, inv_quaternion = PoseInverse.apply(
        position, quaternion, out_position, out_quaternion, adj_pos, adj_quat
    )
    return inv_position, inv_quaternion
@wp.kernel
def compute_transform_point(
    position: wp.array(dtype=wp.vec3),
    quat: wp.array(dtype=wp.vec4),
    pt: wp.array(dtype=wp.vec3),
    n_pts: wp.int32,
    n_poses: wp.int32,
    out_pt: wp.array(dtype=wp.vec3),
):
    # Warp kernel: apply each pose to one shared list of points. Given n points
    # and b poses, the flat output holds b * n transformed points. One thread
    # handles one (pose, point) pair. `n_poses` is not read in the body.
    tid = wp.tid()
    # Split the flat thread id into a pose index and a point index.
    # NOTE(review): relies on `/` between integers yielding an integer index;
    # confirm for the Warp version in use.
    b_idx = tid / (n_pts)
    p_idx = tid - (b_idx * n_pts)
    # read data: the point is indexed by p_idx only, so it is shared by all poses.
    in_position = position[b_idx]
    in_quat = quat[b_idx]
    in_pt = pt[p_idx]
    # Build a transform from the translation and quaternion. Element 0 of the
    # stored quaternion is passed as the last wp.quaternion argument.
    t = wp.transform(in_position, wp.quaternion(in_quat[1], in_quat[2], in_quat[3], in_quat[0]))
    # transform the point
    p = wp.transform_point(t, in_pt)
    # write the result at the flat (pose, point) slot
    out_pt[b_idx * n_pts + p_idx] = p
@wp.kernel
def compute_batch_transform_point(
    position: wp.array(dtype=wp.vec3),
    quat: wp.array(dtype=wp.vec4),
    pt: wp.array(dtype=wp.vec3),
    n_pts: wp.int32,
    n_poses: wp.int32,
    out_pt: wp.array(dtype=wp.vec3),
):
    # Warp kernel: apply pose b to the b-th group of n_pts points. Unlike
    # compute_transform_point, every pose has its own points (the input is
    # indexed by b_idx * n_pts + p_idx). `n_poses` is not read in the body.
    tid = wp.tid()
    # Split the flat thread id into a pose index and a point index.
    # NOTE(review): relies on `/` between integers yielding an integer index;
    # confirm for the Warp version in use.
    b_idx = tid / (n_pts)
    p_idx = tid - (b_idx * n_pts)
    # read data:
    in_position = position[b_idx]
    in_quat = quat[b_idx]
    in_pt = pt[b_idx * n_pts + p_idx]
    # Build a transform from the translation and quaternion. Element 0 of the
    # stored quaternion is passed as the last wp.quaternion argument.
    t = wp.transform(in_position, wp.quaternion(in_quat[1], in_quat[2], in_quat[3], in_quat[0]))
    # transform the point
    p = wp.transform_point(t, in_pt)
    # write the result to the same flat slot the input point was read from
    out_pt[b_idx * n_pts + p_idx] = p
@wp.kernel
def compute_batch_pose_multipy(
    position: wp.array(dtype=wp.vec3),
    quat: wp.array(dtype=wp.vec4),
    position2: wp.array(dtype=wp.vec3),
    quat2: wp.array(dtype=wp.vec4),
    out_position: wp.array(dtype=wp.vec3),
    out_quat: wp.array(dtype=wp.vec4),
):
    # Warp kernel: element-wise pose product. Thread b computes
    # pose_1[b] * pose_2[b]; both pose arrays are indexed with the same b, so
    # each must hold at least as many entries as there are threads.
    b_idx = wp.tid()
    # read data:
    in_position = position[b_idx]
    in_quat = quat[b_idx]
    in_position2 = position2[b_idx]
    in_quat2 = quat2[b_idx]
    # Build both transforms. Element 0 of each stored quaternion is passed as
    # the last wp.quaternion argument.
    t_1 = wp.transform(in_position, wp.quaternion(in_quat[1], in_quat[2], in_quat[3], in_quat[0]))
    t_2 = wp.transform(
        in_position2, wp.quaternion(in_quat2[1], in_quat2[2], in_quat2[3], in_quat2[0])
    )
    # compose the two transforms
    t_3 = wp.transform_multiply(t_1, t_2)
    # Write the result, moving the last quaternion component back to slot 0 so
    # the output uses the same layout as the inputs.
    out_q = wp.transform_get_rotation(t_3)
    out_v = wp.vec4()
    out_v[0] = out_q[3]
    out_v[1] = out_q[0]
    out_v[2] = out_q[1]
    out_v[3] = out_q[2]
    out_position[b_idx] = wp.transform_get_translation(t_3)
    out_quat[b_idx] = out_v
@wp.kernel
def compute_pose_inverse(
    position: wp.array(dtype=wp.vec3),
    quat: wp.array(dtype=wp.vec4),
    out_position: wp.array(dtype=wp.vec3),
    out_quat: wp.array(dtype=wp.vec4),
):
    # Warp kernel: thread b writes the inverse of pose b.
    b_idx = wp.tid()
    # read data:
    in_position = position[b_idx]
    in_quat = quat[b_idx]
    # Build the transform. Element 0 of the stored quaternion is passed as the
    # last wp.quaternion argument.
    t_1 = wp.transform(in_position, wp.quaternion(in_quat[1], in_quat[2], in_quat[3], in_quat[0]))
    t_3 = wp.transform_inverse(t_1)
    # Write the result, moving the last quaternion component back to slot 0 so
    # the output uses the same layout as the input.
    out_q = wp.transform_get_rotation(t_3)
    out_v = wp.vec4()
    out_v[0] = wp.index(out_q, 3)
    out_v[1] = wp.index(out_q, 0)
    out_v[2] = wp.index(out_q, 1)
    out_v[3] = wp.index(out_q, 2)
    out_position[b_idx] = wp.transform_get_translation(t_3)
    out_quat[b_idx] = out_v
@wp.kernel
def compute_quat_to_matrix(
    quat: wp.array(dtype=wp.vec4),
    out_mat: wp.array(dtype=wp.mat33),
):
    # Warp kernel: thread b converts quaternion b into a 3x3 matrix.
    b_idx = wp.tid()
    # read data:
    in_quat = quat[b_idx]
    # Element 0 of the stored quaternion is passed as the last wp.quaternion
    # argument.
    q_1 = wp.quaternion(in_quat[1], in_quat[2], in_quat[3], in_quat[0])
    m_1 = wp.quat_to_matrix(q_1)
    # write the matrix:
    out_mat[b_idx] = m_1
@wp.kernel
def compute_matrix_to_quat(
    in_mat: wp.array(dtype=wp.mat33),
    out_quat: wp.array(dtype=wp.vec4),
):
    # Warp kernel: thread b converts 3x3 matrix b into a quaternion.
    b_idx = wp.tid()
    # read data:
    in_m = in_mat[b_idx]
    # convert the matrix to a Warp quaternion
    out_q = wp.quat_from_matrix(in_m)
    # Reorder so the last component of the Warp quaternion lands in slot 0,
    # matching the layout the other kernels in this file read.
    out_v = wp.vec4()
    out_v[0] = wp.index(out_q, 3)
    out_v[1] = wp.index(out_q, 0)
    out_v[2] = wp.index(out_q, 1)
    out_v[3] = wp.index(out_q, 2)
    # write the quaternion:
    out_quat[b_idx] = out_v
@wp.kernel
def compute_pose_multipy(
    position: wp.array(dtype=wp.vec3),
    quat: wp.array(dtype=wp.vec4),
    position2: wp.array(dtype=wp.vec3),
    quat2: wp.array(dtype=wp.vec4),
    out_position: wp.array(dtype=wp.vec3),
    out_quat: wp.array(dtype=wp.vec4),
):
    # Warp kernel: one left pose times a batch of right poses. Thread b
    # computes pose_1[0] * pose_2[b]; only index 0 of `position` and `quat`
    # is ever read.
    b_idx = wp.tid()
    # read data: the left pose always comes from index 0.
    in_position = position[0]
    in_quat = quat[0]
    in_position2 = position2[b_idx]
    in_quat2 = quat2[b_idx]
    # Build both transforms. Element 0 of each stored quaternion is passed as
    # the last wp.quaternion argument.
    t_1 = wp.transform(in_position, wp.quaternion(in_quat[1], in_quat[2], in_quat[3], in_quat[0]))
    t_2 = wp.transform(
        in_position2, wp.quaternion(in_quat2[1], in_quat2[2], in_quat2[3], in_quat2[0])
    )
    # compose the two transforms
    t_3 = wp.transform_multiply(t_1, t_2)
    # Write the result, moving the last quaternion component back to slot 0 so
    # the output uses the same layout as the inputs.
    out_q = wp.transform_get_rotation(t_3)
    out_v = wp.vec4()
    out_v[0] = out_q[3]
    out_v[1] = out_q[0]
    out_v[2] = out_q[1]
    out_v[3] = out_q[2]
    out_position[b_idx] = wp.transform_get_translation(t_3)
    out_quat[b_idx] = out_v
class TransformPoint(torch.autograd.Function):
    """Autograd wrapper around the ``compute_transform_point`` Warp kernel.

    The forward pass launches one thread per row of ``out_points`` with the
    pose count fixed to 1. The backward pass relaunches the same kernel with
    ``adjoint=True`` and returns gradients for ``position``, ``quaternion``
    and ``points``.
    """

    @staticmethod
    def forward(
        ctx,
        position: torch.Tensor,
        quaternion: torch.Tensor,
        points: torch.Tensor,
        out_points: torch.Tensor,
        adj_position: torch.Tensor,
        adj_quaternion: torch.Tensor,
        adj_points: torch.Tensor,
    ):
        """Write the transformed points into ``out_points`` and return it."""
        num_points, _ = out_points.shape
        init_warp()
        ctx.save_for_backward(
            position, quaternion, points, out_points, adj_position, adj_quaternion, adj_points
        )
        num_poses = 1
        ctx.b = num_poses
        ctx.n = num_points

        wp_position = wp.from_torch(position.detach().view(-1, 3).contiguous(), dtype=wp.vec3)
        wp_quaternion = wp.from_torch(
            quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4
        )
        wp_points = wp.from_torch(points.detach().view(-1, 3).contiguous(), dtype=wp.vec3)
        wp_out_points = wp.from_torch(out_points.view(-1, 3), dtype=wp.vec3)

        wp.launch(
            kernel=compute_transform_point,
            dim=num_poses * num_points,
            inputs=[wp_position, wp_quaternion, wp_points, num_points, num_poses],
            outputs=[wp_out_points],
            stream=wp.stream_from_torch(position.device),
        )
        return out_points

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for the three tensor inputs; buffers get ``None``."""
        (
            position,
            quaternion,
            points,
            out_points,
            adj_position,
            adj_quaternion,
            adj_points,
        ) = ctx.saved_tensors

        # Multiplying by zero yields new tensors shaped like the saved buffers.
        grad_position = 0.0 * adj_position
        grad_quaternion = 0.0 * adj_quaternion
        grad_points = 0.0 * adj_points

        wp_grad_out = wp.from_torch(grad_output.view(-1, 3).contiguous(), dtype=wp.vec3)
        wp_grad_points = wp.from_torch(grad_points, dtype=wp.vec3)
        wp_grad_position = wp.from_torch(grad_position, dtype=wp.vec3)
        wp_grad_quaternion = wp.from_torch(grad_quaternion, dtype=wp.vec4)

        # Attach the gradient arrays to the forward arrays for the adjoint launch.
        wp_position = wp.from_torch(
            position.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_grad_position
        )
        wp_quaternion = wp.from_torch(
            quaternion.view(-1, 4).contiguous(), dtype=wp.vec4, grad=wp_grad_quaternion
        )
        wp_points = wp.from_torch(
            points.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_grad_points
        )
        wp_out_points = wp.from_torch(
            out_points.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_grad_out
        )

        wp.launch(
            kernel=compute_transform_point,
            dim=ctx.b * ctx.n,
            inputs=[wp_position, wp_quaternion, wp_points, ctx.n, ctx.b],
            outputs=[wp_out_points],
            adj_inputs=[None, None, None, ctx.n, ctx.b],
            adj_outputs=[None],
            stream=wp.stream_from_torch(grad_output.device),
            adjoint=True,
        )

        wants = ctx.needs_input_grad
        g_p = grad_position if wants[0] else None
        g_q = grad_quaternion if wants[1] else None
        g_pt = grad_points if wants[2] else None
        return g_p, g_q, g_pt, None, None, None, None
class BatchTransformPoint(torch.autograd.Function):
    """Autograd wrapper around the ``compute_batch_transform_point`` Warp kernel.

    ``out_points`` is three-dimensional ``(b, n, 3)``; the forward pass
    launches ``b * n`` threads. The backward pass relaunches the same kernel
    with ``adjoint=True`` and returns gradients for ``position``,
    ``quaternion`` and ``points``.
    """

    @staticmethod
    def forward(
        ctx,
        position: torch.Tensor,
        quaternion: torch.Tensor,
        points: torch.Tensor,
        out_points: torch.Tensor,
        adj_position: torch.Tensor,
        adj_quaternion: torch.Tensor,
        adj_points: torch.Tensor,
    ):
        """Write the transformed points into ``out_points`` and return it."""
        batch, num_points, _ = out_points.shape
        init_warp()
        points = points.detach()
        ctx.save_for_backward(
            position, quaternion, points, out_points, adj_position, adj_quaternion, adj_points
        )
        ctx.b = batch
        ctx.n = num_points

        wp_position = wp.from_torch(position.detach().view(-1, 3).contiguous(), dtype=wp.vec3)
        wp_quaternion = wp.from_torch(
            quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4
        )
        wp_points = wp.from_torch(points.detach().view(-1, 3).contiguous(), dtype=wp.vec3)
        wp_out_points = wp.from_torch(out_points.view(-1, 3).contiguous(), dtype=wp.vec3)

        wp.launch(
            kernel=compute_batch_transform_point,
            dim=batch * num_points,
            inputs=[wp_position, wp_quaternion, wp_points, num_points, batch],
            outputs=[wp_out_points],
            stream=wp.stream_from_torch(position.device),
        )
        return out_points

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for the three tensor inputs; buffers get ``None``."""
        (
            position,
            quaternion,
            points,
            out_points,
            adj_position,
            adj_quaternion,
            adj_points,
        ) = ctx.saved_tensors
        init_warp()

        wp_grad_out = wp.from_torch(grad_output.view(-1, 3).contiguous(), dtype=wp.vec3)

        # Multiplying by zero yields new tensors shaped like the saved buffers.
        grad_position = 0.0 * adj_position
        grad_quaternion = 0.0 * adj_quaternion
        grad_points = 0.0 * adj_points

        wp_grad_points = wp.from_torch(grad_points.view(-1, 3), dtype=wp.vec3)
        wp_grad_position = wp.from_torch(grad_position.view(-1, 3), dtype=wp.vec3)
        wp_grad_quaternion = wp.from_torch(grad_quaternion.view(-1, 4), dtype=wp.vec4)

        # Attach the gradient arrays to the forward arrays for the adjoint launch.
        wp_position = wp.from_torch(
            position.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_grad_position
        )
        wp_quaternion = wp.from_torch(
            quaternion.view(-1, 4).contiguous(), dtype=wp.vec4, grad=wp_grad_quaternion
        )
        wp_points = wp.from_torch(
            points.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_grad_points
        )
        wp_out_points = wp.from_torch(out_points.view(-1, 3), dtype=wp.vec3, grad=wp_grad_out)

        wp.launch(
            kernel=compute_batch_transform_point,
            dim=ctx.b * ctx.n,
            inputs=[wp_position, wp_quaternion, wp_points, ctx.n, ctx.b],
            outputs=[wp_out_points],
            adj_inputs=[None, None, None, ctx.n, ctx.b],
            adj_outputs=[None],
            stream=wp.stream_from_torch(grad_output.device),
            adjoint=True,
        )

        wants = ctx.needs_input_grad
        g_p = grad_position if wants[0] else None
        g_q = grad_quaternion if wants[1] else None
        g_pt = grad_points if wants[2] else None
        return g_p, g_q, g_pt, None, None, None, None
class BatchTransformPose(torch.autograd.Function):
    """Autograd function for the element-wise pose product ``pose_1[b] * pose_2[b]``.

    Forward launches ``compute_batch_pose_multipy`` with one thread per row of
    ``position``. Backward relaunches the kernel with ``adjoint=True`` and
    returns gradients for the four pose tensors.
    """

    @staticmethod
    def forward(
        ctx,
        position: torch.Tensor,
        quaternion: torch.Tensor,
        position2: torch.Tensor,
        quaternion2: torch.Tensor,
        out_position: Optional[torch.Tensor],
        out_quaternion: Optional[torch.Tensor],
        adj_position: Optional[torch.Tensor],
        adj_quaternion: Optional[torch.Tensor],
        adj_position2: Optional[torch.Tensor],
        adj_quaternion2: Optional[torch.Tensor],
    ):
        """Compute the pose product.

        Output and gradient buffers that are ``None`` are allocated as zero
        tensors shaped like the corresponding input.

        Returns:
            Tuple ``(out_position, out_quaternion)``.
        """
        b, _ = position.shape
        if out_position is None:
            out_position = torch.zeros_like(position2)
        if out_quaternion is None:
            out_quaternion = torch.zeros_like(quaternion2)
        if adj_position is None:
            adj_position = torch.zeros_like(position)
        if adj_quaternion is None:
            adj_quaternion = torch.zeros_like(quaternion)
        if adj_position2 is None:
            adj_position2 = torch.zeros_like(position2)
        if adj_quaternion2 is None:
            adj_quaternion2 = torch.zeros_like(quaternion2)
        init_warp()
        ctx.save_for_backward(
            position,
            quaternion,
            position2,
            quaternion2,
            out_position,
            out_quaternion,
            adj_position,
            adj_quaternion,
            adj_position2,
            adj_quaternion2,
        )
        ctx.b = b
        wp.launch(
            kernel=compute_batch_pose_multipy,
            dim=b,
            inputs=[
                wp.from_torch(position.detach().view(-1, 3).contiguous(), dtype=wp.vec3),
                wp.from_torch(quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4),
                wp.from_torch(position2.detach().view(-1, 3).contiguous(), dtype=wp.vec3),
                wp.from_torch(quaternion2.detach().view(-1, 4).contiguous(), dtype=wp.vec4),
            ],
            outputs=[
                wp.from_torch(out_position.detach().view(-1, 3).contiguous(), dtype=wp.vec3),
                wp.from_torch(out_quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4),
            ],
            stream=wp.stream_from_torch(position.device),
        )
        return out_position, out_quaternion

    @staticmethod
    def backward(ctx, grad_out_position, grad_out_quaternion):
        """Return one gradient per forward input (ten in total).

        Only the four pose tensors can receive a gradient; the six buffer
        arguments always get ``None``.
        """
        (
            position,
            quaternion,
            position2,
            quaternion2,
            out_position,
            out_quaternion,
            adj_position,
            adj_quaternion,
            adj_position2,
            adj_quaternion2,
        ) = ctx.saved_tensors
        init_warp()
        wp_adj_out_position = wp.from_torch(
            grad_out_position.view(-1, 3).contiguous(), dtype=wp.vec3
        )
        wp_adj_out_quaternion = wp.from_torch(
            grad_out_quaternion.view(-1, 4).contiguous(), dtype=wp.vec4
        )
        # Multiplying by zero yields new tensors shaped like the saved buffers.
        adj_position = 0.0 * adj_position
        adj_quaternion = 0.0 * adj_quaternion
        adj_position2 = 0.0 * adj_position2
        adj_quaternion2 = 0.0 * adj_quaternion2
        wp_adj_position = wp.from_torch(adj_position.view(-1, 3), dtype=wp.vec3)
        wp_adj_quat = wp.from_torch(adj_quaternion.view(-1, 4), dtype=wp.vec4)
        wp_adj_position2 = wp.from_torch(adj_position2.view(-1, 3), dtype=wp.vec3)
        wp_adj_quat2 = wp.from_torch(adj_quaternion2.view(-1, 4), dtype=wp.vec4)
        wp.launch(
            kernel=compute_batch_pose_multipy,
            dim=ctx.b,
            inputs=[
                wp.from_torch(
                    position.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_adj_position
                ),
                wp.from_torch(quaternion.view(-1, 4).contiguous(), dtype=wp.vec4, grad=wp_adj_quat),
                wp.from_torch(
                    position2.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_adj_position2
                ),
                wp.from_torch(
                    quaternion2.view(-1, 4).contiguous(), dtype=wp.vec4, grad=wp_adj_quat2
                ),
            ],
            outputs=[
                wp.from_torch(
                    out_position.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_adj_out_position
                ),
                wp.from_torch(
                    out_quaternion.view(-1, 4).contiguous(),
                    dtype=wp.vec4,
                    grad=wp_adj_out_quaternion,
                ),
            ],
            adj_inputs=[
                None,
                None,
                None,
                None,
            ],
            adj_outputs=[
                None,
                None,
            ],
            stream=wp.stream_from_torch(grad_out_position.device),
            adjoint=True,
        )
        g_p1 = g_q1 = g_p2 = g_q2 = None
        if ctx.needs_input_grad[0]:
            g_p1 = adj_position
        if ctx.needs_input_grad[1]:
            g_q1 = adj_quaternion
        if ctx.needs_input_grad[2]:
            g_p2 = adj_position2
        if ctx.needs_input_grad[3]:
            g_q2 = adj_quaternion2
        # forward() takes ten inputs after ctx, so ten gradients are returned:
        # four pose gradients followed by None for each of the six buffers.
        return g_p1, g_q1, g_p2, g_q2, None, None, None, None, None, None
class TransformPose(torch.autograd.Function):
    """Autograd function for one left pose times a batch of right poses.

    ``pose_multiply`` dispatches here when ``position`` has a single row and
    ``position2`` has more than one. Both passes launch
    ``compute_pose_multipy``, which always reads index 0 of the left pose, with
    one thread per row of ``position2``.
    """

    @staticmethod
    def forward(
        ctx,
        position: torch.Tensor,
        quaternion: torch.Tensor,
        position2: torch.Tensor,
        quaternion2: torch.Tensor,
        out_position: Optional[torch.Tensor],
        out_quaternion: Optional[torch.Tensor],
        adj_position: Optional[torch.Tensor],
        adj_quaternion: Optional[torch.Tensor],
        adj_position2: Optional[torch.Tensor],
        adj_quaternion2: Optional[torch.Tensor],
    ):
        """Compute ``pose_1[0] * pose_2[b]`` for every row ``b`` of ``position2``.

        Output and gradient buffers that are ``None`` are allocated as zero
        tensors shaped like the corresponding input.

        Returns:
            Tuple ``(out_position, out_quaternion)``.
        """
        b, _ = position2.shape
        init_warp()
        if out_position is None:
            out_position = torch.zeros_like(position2)
        if out_quaternion is None:
            out_quaternion = torch.zeros_like(quaternion2)
        if adj_position is None:
            adj_position = torch.zeros_like(position)
        if adj_quaternion is None:
            adj_quaternion = torch.zeros_like(quaternion)
        if adj_position2 is None:
            adj_position2 = torch.zeros_like(position2)
        if adj_quaternion2 is None:
            adj_quaternion2 = torch.zeros_like(quaternion2)
        ctx.save_for_backward(
            position,
            quaternion,
            position2,
            quaternion2,
            out_position,
            out_quaternion,
            adj_position,
            adj_quaternion,
            adj_position2,
            adj_quaternion2,
        )
        ctx.b = b
        wp.launch(
            # The left pose may hold a single row while `b` threads are
            # launched, so use the kernel that reads index 0 of it; the batched
            # kernel would index the left pose with the thread id.
            kernel=compute_pose_multipy,
            dim=b,
            inputs=[
                wp.from_torch(position.detach().view(-1, 3).contiguous(), dtype=wp.vec3),
                wp.from_torch(quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4),
                wp.from_torch(position2.detach().view(-1, 3).contiguous(), dtype=wp.vec3),
                wp.from_torch(quaternion2.detach().view(-1, 4).contiguous(), dtype=wp.vec4),
            ],
            outputs=[
                wp.from_torch(out_position.detach().view(-1, 3).contiguous(), dtype=wp.vec3),
                wp.from_torch(out_quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4),
            ],
            stream=wp.stream_from_torch(position.device),
        )
        return out_position, out_quaternion

    @staticmethod
    def backward(ctx, grad_out_position, grad_out_quaternion):
        """Return one gradient per forward input (ten in total).

        Only the four pose tensors can receive a gradient; the six buffer
        arguments always get ``None``.
        """
        (
            position,
            quaternion,
            position2,
            quaternion2,
            out_position,
            out_quaternion,
            adj_position,
            adj_quaternion,
            adj_position2,
            adj_quaternion2,
        ) = ctx.saved_tensors
        init_warp()
        wp_adj_out_position = wp.from_torch(
            grad_out_position.view(-1, 3).contiguous(), dtype=wp.vec3
        )
        wp_adj_out_quaternion = wp.from_torch(
            grad_out_quaternion.view(-1, 4).contiguous(), dtype=wp.vec4
        )
        # Multiplying by zero yields new tensors shaped like the saved buffers.
        adj_position = 0.0 * adj_position
        adj_quaternion = 0.0 * adj_quaternion
        adj_position2 = 0.0 * adj_position2
        adj_quaternion2 = 0.0 * adj_quaternion2
        wp_adj_position = wp.from_torch(adj_position.view(-1, 3), dtype=wp.vec3)
        wp_adj_quat = wp.from_torch(adj_quaternion.view(-1, 4), dtype=wp.vec4)
        wp_adj_position2 = wp.from_torch(adj_position2.view(-1, 3), dtype=wp.vec3)
        wp_adj_quat2 = wp.from_torch(adj_quaternion2.view(-1, 4), dtype=wp.vec4)
        wp.launch(
            # Must be the same kernel as the forward pass so the adjoint
            # matches what was computed.
            kernel=compute_pose_multipy,
            dim=ctx.b,
            inputs=[
                wp.from_torch(
                    position.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_adj_position
                ),
                wp.from_torch(quaternion.view(-1, 4).contiguous(), dtype=wp.vec4, grad=wp_adj_quat),
                wp.from_torch(
                    position2.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_adj_position2
                ),
                wp.from_torch(
                    quaternion2.view(-1, 4).contiguous(), dtype=wp.vec4, grad=wp_adj_quat2
                ),
            ],
            outputs=[
                wp.from_torch(
                    out_position.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_adj_out_position
                ),
                wp.from_torch(
                    out_quaternion.view(-1, 4).contiguous(),
                    dtype=wp.vec4,
                    grad=wp_adj_out_quaternion,
                ),
            ],
            adj_inputs=[
                None,
                None,
                None,
                None,
            ],
            adj_outputs=[
                None,
                None,
            ],
            stream=wp.stream_from_torch(grad_out_position.device),
            adjoint=True,
        )
        g_p1 = g_q1 = g_p2 = g_q2 = None
        if ctx.needs_input_grad[0]:
            g_p1 = adj_position
        if ctx.needs_input_grad[1]:
            g_q1 = adj_quaternion
        if ctx.needs_input_grad[2]:
            g_p2 = adj_position2
        if ctx.needs_input_grad[3]:
            g_q2 = adj_quaternion2
        # forward() takes ten inputs after ctx, so ten gradients are returned:
        # four pose gradients followed by None for each of the six buffers.
        return g_p1, g_q1, g_p2, g_q2, None, None, None, None, None, None
class PoseInverse(torch.autograd.Function):
    """Autograd function that inverts a batch of poses.

    Forward launches ``compute_pose_inverse`` with one thread per row of
    ``position``. Backward relaunches the kernel with ``adjoint=True`` and
    returns gradients for ``position`` and ``quaternion``.
    """

    @staticmethod
    def forward(
        ctx,
        position: torch.Tensor,
        quaternion: torch.Tensor,
        out_position: Optional[torch.Tensor],
        out_quaternion: Optional[torch.Tensor],
        adj_position: Optional[torch.Tensor],
        adj_quaternion: Optional[torch.Tensor],
    ):
        """Compute the inverse pose.

        Output and gradient buffers that are ``None`` are allocated as zero
        tensors shaped like the corresponding input.

        Returns:
            Tuple ``(out_position, out_quaternion)``.
        """
        b, _ = position.shape
        if out_position is None:
            out_position = torch.zeros_like(position)
        if out_quaternion is None:
            out_quaternion = torch.zeros_like(quaternion)
        if adj_position is None:
            adj_position = torch.zeros_like(position)
        if adj_quaternion is None:
            adj_quaternion = torch.zeros_like(quaternion)
        init_warp()
        ctx.save_for_backward(
            position,
            quaternion,
            out_position,
            out_quaternion,
            adj_position,
            adj_quaternion,
        )
        ctx.b = b
        wp.launch(
            kernel=compute_pose_inverse,
            dim=b,
            inputs=[
                wp.from_torch(position.detach().view(-1, 3).contiguous(), dtype=wp.vec3),
                wp.from_torch(quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4),
            ],
            outputs=[
                wp.from_torch(out_position.detach().view(-1, 3).contiguous(), dtype=wp.vec3),
                wp.from_torch(out_quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4),
            ],
            stream=wp.stream_from_torch(position.device),
        )
        return out_position, out_quaternion

    @staticmethod
    def backward(ctx, grad_out_position, grad_out_quaternion):
        """Return one gradient per forward input (six in total).

        Only ``position`` and ``quaternion`` can receive a gradient; the four
        buffer arguments always get ``None``.
        """
        (
            position,
            quaternion,
            out_position,
            out_quaternion,
            adj_position,
            adj_quaternion,
        ) = ctx.saved_tensors
        init_warp()
        wp_adj_out_position = wp.from_torch(
            grad_out_position.view(-1, 3).contiguous(), dtype=wp.vec3
        )
        wp_adj_out_quaternion = wp.from_torch(
            grad_out_quaternion.view(-1, 4).contiguous(), dtype=wp.vec4
        )
        # Multiplying by zero yields new tensors shaped like the saved buffers.
        adj_position = 0.0 * adj_position
        adj_quaternion = 0.0 * adj_quaternion
        wp_adj_position = wp.from_torch(adj_position.view(-1, 3), dtype=wp.vec3)
        wp_adj_quat = wp.from_torch(adj_quaternion.view(-1, 4), dtype=wp.vec4)
        wp.launch(
            kernel=compute_pose_inverse,
            dim=ctx.b,
            inputs=[
                wp.from_torch(
                    position.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_adj_position
                ),
                wp.from_torch(quaternion.view(-1, 4).contiguous(), dtype=wp.vec4, grad=wp_adj_quat),
            ],
            outputs=[
                wp.from_torch(
                    out_position.view(-1, 3).contiguous(), dtype=wp.vec3, grad=wp_adj_out_position
                ),
                wp.from_torch(
                    out_quaternion.view(-1, 4).contiguous(),
                    dtype=wp.vec4,
                    grad=wp_adj_out_quaternion,
                ),
            ],
            # One placeholder per kernel argument: two inputs, two outputs.
            adj_inputs=[
                None,
                None,
            ],
            adj_outputs=[
                None,
                None,
            ],
            stream=wp.stream_from_torch(grad_out_position.device),
            adjoint=True,
        )
        g_p1 = g_q1 = None
        if ctx.needs_input_grad[0]:
            g_p1 = adj_position
        if ctx.needs_input_grad[1]:
            g_q1 = adj_quaternion
        # forward() takes six inputs after ctx, so six gradients are returned.
        return g_p1, g_q1, None, None, None, None
class QuatToMatrix(torch.autograd.Function):
    """Autograd wrapper around the ``compute_quat_to_matrix`` Warp kernel.

    Expects a two-dimensional ``quaternion`` tensor and launches one thread
    per row. The backward pass relaunches the kernel with ``adjoint=True``.
    """

    @staticmethod
    def forward(
        ctx,
        quaternion: torch.Tensor,
        out_mat: torch.Tensor,
        adj_quaternion: torch.Tensor,
    ):
        """Convert each quaternion row to a 3x3 matrix and return ``out_mat``.

        ``out_mat`` and ``adj_quaternion`` are allocated as zeros when ``None``.
        """
        batch, _ = quaternion.shape
        if out_mat is None:
            out_mat = torch.zeros(
                quaternion.shape[0], 3, 3, dtype=quaternion.dtype, device=quaternion.device
            )
        if adj_quaternion is None:
            adj_quaternion = torch.zeros_like(quaternion)
        init_warp()
        ctx.save_for_backward(quaternion, out_mat, adj_quaternion)
        ctx.b = batch

        wp_quaternion = wp.from_torch(
            quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4
        )
        wp_out_mat = wp.from_torch(out_mat.detach().view(-1, 3, 3).contiguous(), dtype=wp.mat33)
        wp.launch(
            kernel=compute_quat_to_matrix,
            dim=batch,
            inputs=[wp_quaternion],
            outputs=[wp_out_mat],
            stream=wp.stream_from_torch(quaternion.device),
        )
        return out_mat

    @staticmethod
    def backward(ctx, grad_out_mat):
        """Return the quaternion gradient followed by ``None`` for both buffers."""
        quaternion, out_mat, adj_quaternion = ctx.saved_tensors
        init_warp()

        wp_grad_out = wp.from_torch(grad_out_mat.view(-1, 3, 3).contiguous(), dtype=wp.mat33)
        # Multiplying by zero yields a new tensor shaped like the saved buffer.
        grad_quaternion = 0.0 * adj_quaternion
        wp_grad_quaternion = wp.from_torch(grad_quaternion.view(-1, 4), dtype=wp.vec4)

        # Attach the gradient arrays to the forward arrays for the adjoint launch.
        wp_quaternion = wp.from_torch(
            quaternion.view(-1, 4).contiguous(), dtype=wp.vec4, grad=wp_grad_quaternion
        )
        wp_out_mat = wp.from_torch(
            out_mat.view(-1, 3, 3).contiguous(), dtype=wp.mat33, grad=wp_grad_out
        )
        wp.launch(
            kernel=compute_quat_to_matrix,
            dim=ctx.b,
            inputs=[wp_quaternion],
            outputs=[wp_out_mat],
            adj_inputs=[None],
            adj_outputs=[None],
            stream=wp.stream_from_torch(grad_out_mat.device),
            adjoint=True,
        )

        g_q1 = grad_quaternion if ctx.needs_input_grad[0] else None
        return g_q1, None, None
class MatrixToQuaternion(torch.autograd.Function):
    """Autograd wrapper around the ``compute_matrix_to_quat`` Warp kernel.

    Expects a three-dimensional ``in_mat`` tensor and launches one thread per
    matrix. The backward pass relaunches the kernel with ``adjoint=True``.
    """

    @staticmethod
    def forward(
        ctx,
        in_mat: torch.Tensor,
        out_quaternion: torch.Tensor,
        adj_mat: torch.Tensor,
    ):
        """Convert each matrix to a quaternion and return ``out_quaternion``.

        ``out_quaternion`` and ``adj_mat`` are allocated as zeros when ``None``.
        """
        batch, _, _ = in_mat.shape
        if out_quaternion is None:
            out_quaternion = torch.zeros(
                in_mat.shape[0], 4, dtype=in_mat.dtype, device=in_mat.device
            )
        if adj_mat is None:
            adj_mat = torch.zeros_like(in_mat)
        init_warp()
        ctx.save_for_backward(in_mat, out_quaternion, adj_mat)
        ctx.b = batch

        wp_in_mat = wp.from_torch(in_mat.detach().view(-1, 3, 3).contiguous(), dtype=wp.mat33)
        wp_out_quat = wp.from_torch(
            out_quaternion.detach().view(-1, 4).contiguous(), dtype=wp.vec4
        )
        wp.launch(
            kernel=compute_matrix_to_quat,
            dim=batch,
            inputs=[wp_in_mat],
            outputs=[wp_out_quat],
            stream=wp.stream_from_torch(in_mat.device),
        )
        return out_quaternion

    @staticmethod
    def backward(ctx, grad_out_q):
        """Return the matrix gradient followed by ``None`` for both buffers."""
        in_mat, out_quaternion, adj_mat = ctx.saved_tensors
        init_warp()

        wp_grad_out = wp.from_torch(grad_out_q.view(-1, 4).contiguous(), dtype=wp.vec4)
        # Multiplying by zero yields a new tensor shaped like the saved buffer.
        grad_mat = 0.0 * adj_mat
        wp_grad_mat = wp.from_torch(grad_mat.view(-1, 3, 3), dtype=wp.mat33)

        # Attach the gradient arrays to the forward arrays for the adjoint launch.
        wp_in_mat = wp.from_torch(
            in_mat.view(-1, 3, 3).contiguous(), dtype=wp.mat33, grad=wp_grad_mat
        )
        wp_out_quat = wp.from_torch(
            out_quaternion.view(-1, 4).contiguous(), dtype=wp.vec4, grad=wp_grad_out
        )
        wp.launch(
            kernel=compute_matrix_to_quat,
            dim=ctx.b,
            inputs=[wp_in_mat],
            outputs=[wp_out_quat],
            adj_inputs=[None],
            adj_outputs=[None],
            stream=wp.stream_from_torch(grad_out_q.device),
            adjoint=True,
        )

        g_q1 = grad_mat if ctx.needs_input_grad[0] else None
        return g_q1, None, None