737 lines
18 KiB
Python
737 lines
18 KiB
Python
#
|
|
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
|
# property and proprietary rights in and to this material, related
|
|
# documentation and any modifications thereto. Any use, reproduction,
|
|
# disclosure or distribution of this material and related documentation
|
|
# without an express license agreement from NVIDIA CORPORATION or
|
|
# its affiliates is strictly prohibited.
|
|
#
|
|
# Third Party
|
|
import torch
|
|
|
|
# CuRobo
|
|
from curobo.util.logger import log_warn
|
|
|
|
try:
|
|
# CuRobo
|
|
from curobo.curobolib import geom_cu
|
|
|
|
except ImportError:
|
|
log_warn("geom_cu binary not found, jit compiling...")
|
|
# Third Party
|
|
from torch.utils.cpp_extension import load
|
|
|
|
# CuRobo
|
|
from curobo.util_file import add_cpp_path
|
|
|
|
geom_cu = load(
|
|
name="geom_cu",
|
|
sources=add_cpp_path(
|
|
[
|
|
"geom_cuda.cpp",
|
|
"sphere_obb_kernel.cu",
|
|
"pose_distance_kernel.cu",
|
|
"self_collision_kernel.cu",
|
|
]
|
|
),
|
|
)
|
|
|
|
|
|
def get_self_collision_distance(
    out_distance,
    out_vec,
    sparse_index,
    robot_spheres,
    collision_offset,
    weight,
    coll_matrix,
    thread_locations,
    thread_size,
    b_size,
    nspheres,
    compute_grad,
    checks_per_thread=32,
    experimental_kernel=True,
):
    """Invoke the ``geom_cu.self_collision_distance`` extension kernel.

    All arguments are forwarded positionally, in the order received, to the
    compiled extension; this wrapper only unpacks the kernel's result.

    Args:
        out_distance: Output buffer handed to the kernel (presumably receives the
            per-query distance/cost -- confirm against the C++ binding).
        out_vec: Output buffer handed to the kernel (presumably receives the
            gradient vectors).
        sparse_index: Forwarded unchanged to the kernel.
        robot_spheres: Sphere tensor forwarded unchanged to the kernel.
        collision_offset: Forwarded unchanged to the kernel.
        weight: Forwarded unchanged to the kernel.
        coll_matrix: Forwarded unchanged to the kernel.
        thread_locations: Forwarded unchanged to the kernel.
        thread_size: Forwarded unchanged to the kernel.
        b_size: Forwarded unchanged to the kernel (callers in this file pass
            ``batch * horizon``).
        nspheres: Forwarded unchanged to the kernel.
        compute_grad: Forwarded unchanged to the kernel.
        checks_per_thread: Forwarded to the kernel; defaults to 32.
        experimental_kernel: Forwarded to the kernel; defaults to True.

    Returns:
        Tuple ``(distance, vec)`` made of elements 0 and 1 of the kernel result.
    """
    kernel_out = geom_cu.self_collision_distance(
        out_distance,
        out_vec,
        sparse_index,
        robot_spheres,
        collision_offset,
        weight,
        coll_matrix,
        thread_locations,
        thread_size,
        b_size,
        nspheres,
        compute_grad,
        checks_per_thread,
        experimental_kernel,
    )
    distance, vec = kernel_out[0], kernel_out[1]
    return distance, vec
|
|
|
|
|
|
class SelfCollisionDistance(torch.autograd.Function):
    """Autograd wrapper around the ``geom_cu`` self-collision kernel.

    ``forward`` runs the kernel through ``get_self_collision_distance`` and
    stashes the kernel's second output for ``backward``, which hands it back as
    the gradient of ``robot_spheres`` (scaled by the upstream gradient only when
    ``return_loss`` is set).
    """

    @staticmethod
    def forward(
        ctx,
        out_distance,
        out_vec,
        sparse_idx,
        robot_spheres,
        sphere_offset,
        weight,
        coll_matrix,
        thread_locations,
        max_thread,
        checks_per_thread: int,
        experimental_kernel: bool,
        return_loss: bool = False,
    ):
        """Compute self-collision distances for ``robot_spheres``.

        ``robot_spheres`` must be 4-D; its first three extents are read as
        (batch, horizon, n_spheres) and ``batch * horizon`` is passed to the
        kernel as the flattened batch size. The kernel is told to produce
        gradients exactly when ``robot_spheres.requires_grad`` is True.

        Returns:
            The distance tensor returned by the kernel.
        """
        batch, horizon, sphere_count, _ = robot_spheres.shape
        distance, grad_vec = get_self_collision_distance(
            out_distance,
            out_vec,
            sparse_idx,
            robot_spheres,
            sphere_offset,
            weight,
            coll_matrix,
            thread_locations,
            max_thread,
            batch * horizon,
            sphere_count,
            robot_spheres.requires_grad,
            checks_per_thread,
            experimental_kernel,
        )
        ctx.return_loss = return_loss
        ctx.save_for_backward(grad_vec)
        return distance

    @staticmethod
    def backward(ctx, grad_out_distance):
        """Return one gradient slot per forward input; only ``robot_spheres`` (index 3) gets one."""
        grads = [None] * 12
        if ctx.needs_input_grad[3]:
            (grad_vec,) = ctx.saved_tensors
            if ctx.return_loss:
                # NOTE(review): unsqueeze(1) assumes a particular layout of the
                # upstream gradient vs. the saved buffer -- confirm the shapes
                # broadcast as intended for horizon > 1.
                grad_vec = grad_vec * grad_out_distance.unsqueeze(1)
            grads[3] = grad_vec
        return tuple(grads)
|
|
|
|
|
|
class SelfCollisionDistanceLoss(SelfCollisionDistance):
    """Variant of ``SelfCollisionDistance`` whose backward always chains the upstream gradient.

    ``forward`` is inherited unchanged. ``backward`` multiplies the saved
    gradient buffer by the incoming gradient regardless of ``ctx.return_loss``.
    """

    @staticmethod
    def backward(ctx, grad_out_distance):
        """Return the ``robot_spheres`` gradient, with ``None`` for every other input.

        Autograd requires one returned entry per ``forward`` input. The
        inherited ``forward`` receives 12 inputs when ``return_loss`` is passed
        explicitly and 11 when it is left at its default, so the tuple length is
        taken from ``ctx.needs_input_grad`` instead of being hard-coded (the
        previous hard-coded 11 broke the 12-argument call).
        """
        sphere_grad = None
        if ctx.needs_input_grad[3]:
            (g_vec,) = ctx.saved_tensors
            # NOTE(review): unsqueeze(1) mirrors SelfCollisionDistance.backward;
            # confirm it broadcasts correctly against the saved buffer's shape.
            sphere_grad = g_vec * grad_out_distance.unsqueeze(1)
        n_inputs = len(ctx.needs_input_grad)
        return (None, None, None, sphere_grad) + (None,) * (n_inputs - 4)
|
|
|
|
|
|
def get_pose_distance(
    out_distance,
    out_position_distance,
    out_rotation_distance,
    out_p_vec,
    out_q_vec,
    out_idx,
    current_position,
    goal_position,
    current_quat,
    goal_quat,
    vec_weight,
    weight,
    vec_convergence,
    run_weight,
    run_vec_weight,
    offset_waypoint,
    offset_tstep_fraction,
    batch_pose_idx,
    batch_size,
    horizon,
    mode=1,
    num_goals=1,
    write_grad=False,
    write_distance=False,
    use_metric=False,
    project_distance=True,
):
    """Invoke the ``geom_cu.pose_distance`` kernel after validating the index buffer.

    ``goal_position`` and ``goal_quat`` are flattened with ``view(-1)`` before
    being passed to the extension; every other argument is forwarded unchanged
    and in order.

    Raises:
        ValueError: If ``batch_pose_idx.shape[0]`` differs from ``batch_size``.

    Returns:
        The first six kernel outputs as a tuple, in order:
        ``(distance, position_distance, rotation_distance, p_vec, q_vec, idx)``.
    """
    if batch_pose_idx.shape[0] != batch_size:
        raise ValueError("Index buffer size is different from batch size")

    flat_goal_position = goal_position.view(-1)
    flat_goal_quat = goal_quat.view(-1)
    kernel_out = geom_cu.pose_distance(
        out_distance,
        out_position_distance,
        out_rotation_distance,
        out_p_vec,
        out_q_vec,
        out_idx,
        current_position,
        flat_goal_position,
        current_quat,
        flat_goal_quat,
        vec_weight,
        weight,
        vec_convergence,
        run_weight,
        run_vec_weight,
        offset_waypoint,
        offset_tstep_fraction,
        batch_pose_idx,
        batch_size,
        horizon,
        mode,
        num_goals,
        write_grad,
        write_distance,
        use_metric,
        project_distance,
    )
    return tuple(kernel_out[k] for k in range(6))
|
|
|
|
|
|
def get_pose_distance_backward(
    out_grad_p,
    out_grad_q,
    grad_distance,
    grad_p_distance,
    grad_q_distance,
    pose_weight,
    grad_p_vec,
    grad_q_vec,
    batch_size,
    use_distance=False,
):
    """Invoke the ``geom_cu.pose_distance_backward`` kernel.

    All arguments are forwarded positionally, in order, to the extension.

    Args:
        out_grad_p: Output buffer handed to the kernel (presumably for the
            position gradient).
        out_grad_q: Output buffer handed to the kernel (presumably for the
            quaternion gradient).
        grad_distance: Upstream gradient of the combined distance.
        grad_p_distance: Upstream gradient of the position distance.
        grad_q_distance: Upstream gradient of the rotation distance.
        pose_weight: Weight tensor forwarded to the kernel.
        grad_p_vec: Saved position vectors from the forward pass.
        grad_q_vec: Saved rotation vectors from the forward pass.
        batch_size: Forwarded unchanged to the kernel.
        use_distance: Forwarded unchanged to the kernel; defaults to False.

    Returns:
        Tuple ``(position_grad, quat_grad)`` made of elements 0 and 1 of the
        kernel result.
    """
    kernel_out = geom_cu.pose_distance_backward(
        out_grad_p,
        out_grad_q,
        grad_distance,
        grad_p_distance,
        grad_q_distance,
        pose_weight,
        grad_p_vec,
        grad_q_vec,
        batch_size,
        use_distance,
    )
    position_grad, quat_grad = kernel_out[0], kernel_out[1]
    return position_grad, quat_grad
|
|
|
|
|
|
# Scale a saved direction vector by the combined upstream gradient:
#   grad = (grad_g_dist + grad_out_distance * weight)[..., None] * g_vec
# The `1.0 *` keeps the original promotion behaviour for non-float inputs.
@torch.jit.script
def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
    combined = grad_g_dist + grad_out_distance * weight
    scale = 1.0 * combined.unsqueeze(-1)
    return scale * g_vec
|
|
|
|
|
|
# Position and rotation gradients in one scripted call:
#   p_grad = (grad_g_dist + grad_out_distance * p_w)[..., None] * g_vec_p
#   q_grad = (grad_r_err  + grad_out_distance * q_w)[..., None] * g_vec_q
@torch.jit.script
def backward_full_PoseError_jit(
    grad_out_distance, grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
):
    p_scale = grad_g_dist + (grad_out_distance * p_w)
    q_scale = grad_r_err + (grad_out_distance * q_w)
    return p_scale.unsqueeze(-1) * g_vec_p, q_scale.unsqueeze(-1) * g_vec_q
|
|
|
|
|
|
class PoseErrorDistance(torch.autograd.Function):
    """Autograd wrapper around ``geom_cu.pose_distance`` that also returns the split distances.

    ``forward`` asks the kernel to write distances (``write_distance=True``)
    and returns the combined, position and rotation distances plus the index
    buffer. ``backward`` uses the fused CUDA backward when both position and
    quaternion gradients are needed, and a TorchScript helper when only one is.
    """

    @staticmethod
    def forward(
        ctx,
        current_position,
        goal_position,
        current_quat,
        goal_quat,
        vec_weight,
        weight,
        vec_convergence,
        run_weight,
        run_vec_weight,
        offset_waypoint,
        offset_tstep_fraction,
        batch_pose_idx,
        out_distance,
        out_position_distance,
        out_rotation_distance,
        out_p_vec,
        out_r_vec,
        out_idx,
        out_p_grad,
        out_q_grad,
        batch_size,
        horizon,
        mode,  # =PoseErrorType.BATCH_GOAL.value,
        num_goals,
        use_metric,  # =False,
        project_distance,  # =True,
    ):
        """Run the pose-distance kernel.

        ``current_position`` and ``current_quat`` are made contiguous before the
        call; gradient vectors are requested exactly when
        ``current_position.requires_grad`` is True.

        Returns:
            ``(distance, position_distance, rotation_distance, idx)``.
        """
        (
            distance,
            position_distance,
            rotation_distance,
            p_vec,
            r_vec,
            idx,
        ) = get_pose_distance(
            out_distance,
            out_position_distance,
            out_rotation_distance,
            out_p_vec,
            out_r_vec,
            out_idx,
            current_position.contiguous(),
            goal_position,
            current_quat.contiguous(),
            goal_quat,
            vec_weight,
            weight,
            vec_convergence,
            run_weight,
            run_vec_weight,
            offset_waypoint,
            offset_tstep_fraction,
            batch_pose_idx,
            batch_size,
            horizon,
            mode,
            num_goals,
            current_position.requires_grad,  # write_grad
            True,  # write_distance
            use_metric,
            project_distance,
        )
        # The output gradient buffers are saved so the fused backward can write into them.
        ctx.save_for_backward(p_vec, r_vec, weight, out_p_grad, out_q_grad)
        return distance, position_distance, rotation_distance, idx

    @staticmethod
    def backward(ctx, grad_out_distance, grad_g_dist, grad_r_err, grad_out_idx):
        """Return gradients for ``current_position`` (slot 0) and ``current_quat`` (slot 2)."""
        g_vec_p, g_vec_q, weight, out_grad_p, out_grad_q = ctx.saved_tensors
        # Flattened pose count; the saved position vectors are at least 2-D here.
        flat_batch = g_vec_p.shape[0] * g_vec_p.shape[1]
        want_pos = ctx.needs_input_grad[0]
        want_quat = ctx.needs_input_grad[2]
        pos_grad = None
        quat_grad = None

        if want_pos and want_quat:
            pos_grad, quat_grad = get_pose_distance_backward(
                out_grad_p,
                out_grad_q,
                grad_out_distance.contiguous(),
                grad_g_dist.contiguous(),
                grad_r_err.contiguous(),
                weight,
                g_vec_p,
                g_vec_q,
                flat_batch,
                use_distance=True,
            )
        elif want_pos:
            # weight[1] is used for the position-only path, weight[0] for rotation.
            pos_grad = backward_PoseError_jit(grad_g_dist, grad_out_distance, weight[1], g_vec_p)
        elif want_quat:
            quat_grad = backward_PoseError_jit(grad_r_err, grad_out_distance, weight[0], g_vec_q)

        # 26 forward inputs in total: only slots 0 and 2 carry gradients.
        return (pos_grad, None, quat_grad) + (None,) * 23
|
|
|
|
|
|
class PoseError(torch.autograd.Function):
    """Autograd wrapper around ``geom_cu.pose_distance`` returning only the combined distance.

    ``forward`` runs the kernel without writing the split distances
    (``write_distance=False``) and saves the position/rotation vectors the
    kernel returns. ``backward`` hands those vectors back as the gradients of
    ``current_position`` and ``current_quat``, scaled by the upstream gradient
    when ``return_loss`` is set.
    """

    @staticmethod
    def forward(
        ctx,
        current_position: torch.Tensor,
        goal_position: torch.Tensor,
        current_quat: torch.Tensor,
        goal_quat,
        vec_weight,
        weight,
        vec_convergence,
        run_weight,
        run_vec_weight,
        offset_waypoint,
        offset_tstep_fraction,
        batch_pose_idx,
        out_distance,
        out_position_distance,
        out_rotation_distance,
        out_p_vec,
        out_r_vec,
        out_idx,
        out_p_grad,
        out_q_grad,
        batch_size,
        horizon,
        mode,
        num_goals,
        use_metric,
        project_distance,
        return_loss,
    ):
        """Compute the pose error between current and goal poses.

        Args:
            ctx: Autograd context.
            current_position: Made contiguous before the kernel call; its
                ``requires_grad`` flag decides whether the kernel writes
                gradient vectors.
            goal_position: Flattened with ``view(-1)`` by ``get_pose_distance``.
            current_quat: Made contiguous before the kernel call.
            goal_quat: Flattened with ``view(-1)`` by ``get_pose_distance``.
            vec_weight, weight, vec_convergence, run_weight, run_vec_weight,
            offset_waypoint, offset_tstep_fraction: Forwarded unchanged to the
                kernel.
            batch_pose_idx: Index buffer; must have ``batch_size`` entries along
                dim 0 (``get_pose_distance`` raises ``ValueError`` otherwise).
            out_distance, out_position_distance, out_rotation_distance,
            out_p_vec, out_r_vec, out_idx: Output buffers handed to the kernel.
            out_p_grad, out_q_grad: Unused by this function (accepted so the
                signature matches ``PoseErrorDistance``).
            batch_size, horizon, mode, num_goals, use_metric, project_distance:
                Forwarded unchanged to the kernel.
            return_loss: When True, ``backward`` multiplies the saved vectors by
                the upstream gradient; otherwise they are returned as-is.

        Returns:
            The combined distance tensor produced by the kernel.
        """
        ctx.return_loss = return_loss
        (distance, _, _, p_vec, r_vec, _) = get_pose_distance(
            out_distance,
            out_position_distance,
            out_rotation_distance,
            out_p_vec,
            out_r_vec,
            out_idx,
            current_position.contiguous(),
            goal_position,
            current_quat.contiguous(),
            goal_quat,
            vec_weight,
            weight,
            vec_convergence,
            run_weight,
            run_vec_weight,
            offset_waypoint,
            offset_tstep_fraction,
            batch_pose_idx,
            batch_size,
            horizon,
            mode,
            num_goals,
            current_position.requires_grad,  # write_grad
            False,  # write_distance
            use_metric,
            project_distance,
        )
        ctx.save_for_backward(p_vec, r_vec)
        return distance

    @staticmethod
    def backward(ctx, grad_out_distance):
        """Return gradients for ``current_position`` (slot 0) and ``current_quat`` (slot 2).

        With ``return_loss`` set, the saved vectors are scaled by the upstream
        gradient. A trailing axis is added with ``unsqueeze(-1)``, so each pose's
        scalar gradient broadcasts over that pose's vector components.
        ``unsqueeze(1)`` matched this only for 1-D or horizon-1 gradients. For a
        (batch, horizon) gradient it lined the horizon axis up against the
        component axis, which either failed to broadcast or mis-scaled silently.
        """
        pos_grad = None
        quat_grad = None
        want_pos = ctx.needs_input_grad[0]
        want_quat = ctx.needs_input_grad[2]

        if want_pos or want_quat:
            g_vec_p, g_vec_q = ctx.saved_tensors
            # Assumes grad_out_distance has one entry per pose, i.e. the saved
            # vectors' shape minus the trailing component axis -- TODO confirm.
            scale = grad_out_distance.unsqueeze(-1) if ctx.return_loss else None
            if want_pos:
                pos_grad = g_vec_p if scale is None else g_vec_p * scale
            if want_quat:
                quat_grad = g_vec_q if scale is None else g_vec_q * scale

        # 27 forward inputs in total: only slots 0 and 2 carry gradients.
        return (pos_grad, None, quat_grad) + (None,) * 24
|
|
|
|
|
|
class SdfSphereOBB(torch.autograd.Function):
    """Autograd wrapper around ``geom_cu.closest_point`` (sphere vs. oriented boxes).

    ``forward`` returns element 0 of the kernel result and saves element 1 for
    ``backward``, which returns it as the gradient of ``query_sphere`` (scaled by
    the upstream gradient, with a trailing axis added, when ``return_loss`` is set).
    """

    @staticmethod
    def forward(
        ctx,
        query_sphere,
        out_buffer,
        grad_out_buffer,
        sparsity_idx,
        weight,
        activation_distance,
        box_accel,
        box_dims,
        box_pose,
        box_enable,
        n_env_obb,
        env_query_idx,
        max_nobs,
        batch_size,
        horizon,
        n_spheres,
        transform_back,
        compute_distance,
        use_batch_env,
        return_loss: bool = False,
    ):
        """Run the closest-point kernel; every argument except ``return_loss`` is forwarded in order."""
        kernel_out = geom_cu.closest_point(
            query_sphere,
            out_buffer,
            grad_out_buffer,
            sparsity_idx,
            weight,
            activation_distance,
            box_accel,
            box_dims,
            box_pose,
            box_enable,
            n_env_obb,
            env_query_idx,
            max_nobs,
            batch_size,
            horizon,
            n_spheres,
            transform_back,
            compute_distance,
            use_batch_env,
        )
        ctx.return_loss = return_loss
        # Presumably element 1 is the gradient buffer -- confirm against the binding.
        ctx.save_for_backward(kernel_out[1])
        return kernel_out[0]

    @staticmethod
    def backward(ctx, grad_output):
        """Return one slot per forward input (20); only ``query_sphere`` gets a gradient."""
        grads = [None] * 20
        if ctx.needs_input_grad[0]:
            (saved_grad,) = ctx.saved_tensors
            if ctx.return_loss:
                saved_grad = saved_grad * grad_output.unsqueeze(-1)
            grads[0] = saved_grad
        return tuple(grads)
|
|
|
|
|
|
class SdfSweptSphereOBB(torch.autograd.Function):
    """Autograd wrapper around ``geom_cu.swept_closest_point`` (swept spheres vs. oriented boxes).

    ``forward`` returns element 0 of the kernel result and saves element 1 for
    ``backward``, which returns it as the gradient of ``query_sphere`` (scaled by
    the upstream gradient, with a trailing axis added, when ``return_loss`` is set).
    """

    @staticmethod
    def forward(
        ctx,
        query_sphere,
        out_buffer,
        grad_out_buffer,
        sparsity_idx,
        weight,
        activation_distance,
        speed_dt,
        box_accel,
        box_dims,
        box_pose,
        box_enable,
        n_env_obb,
        env_query_idx,
        max_nobs,
        batch_size,
        horizon,
        n_spheres,
        sweep_steps,
        enable_speed_metric,
        transform_back,
        compute_distance,
        use_batch_env,
        return_loss: bool = False,
    ):
        """Run the swept closest-point kernel; every argument except ``return_loss`` is forwarded in order."""
        kernel_out = geom_cu.swept_closest_point(
            query_sphere,
            out_buffer,
            grad_out_buffer,
            sparsity_idx,
            weight,
            activation_distance,
            speed_dt,
            box_accel,
            box_dims,
            box_pose,
            box_enable,
            n_env_obb,
            env_query_idx,
            max_nobs,
            batch_size,
            horizon,
            n_spheres,
            sweep_steps,
            enable_speed_metric,
            transform_back,
            compute_distance,
            use_batch_env,
        )
        ctx.return_loss = return_loss
        # Presumably element 1 is the gradient buffer -- confirm against the binding.
        ctx.save_for_backward(kernel_out[1])
        return kernel_out[0]

    @staticmethod
    def backward(ctx, grad_output):
        """Return one slot per forward input (23); only ``query_sphere`` gets a gradient."""
        grads = [None] * 23
        if ctx.needs_input_grad[0]:
            (saved_grad,) = ctx.saved_tensors
            if ctx.return_loss:
                saved_grad = saved_grad * grad_output.unsqueeze(-1)
            grads[0] = saved_grad
        return tuple(grads)
|