Add support for older warp versions (<1.0.0)

This commit is contained in:
Balakumar Sundaralingam
2024-06-24 15:06:25 -07:00
parent 0c51dd2da8
commit ec527d77e9
14 changed files with 1101 additions and 512 deletions

View File

@@ -10,6 +10,14 @@ its affiliates is strictly prohibited.
--> -->
# Changelog # Changelog
## Latest Commit
### BugFixes & Misc.
- Add support for older warp versions (<1.0.0) as it's not possible to run older isaac sim with
newer warp versions.
- Add override option to mpc dataclass.
- Fix bug in ``PoseCost.forward_pose()`` which caused ``torch_layers_example.py`` to fail.
## Version 0.7.3 ## Version 0.7.3
### New Features ### New Features

View File

@@ -19,7 +19,11 @@ Use [Discussions](https://github.com/NVlabs/curobo/discussions) for questions on
Use [Issues](https://github.com/NVlabs/curobo/issues) if you find a bug. Use [Issues](https://github.com/NVlabs/curobo/issues) if you find a bug.
For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/) cuRobo's collision-free motion planner is available for commercial applications as a
MoveIt plugin: [Isaac ROS cuMotion](https://github.com/NVIDIA-ISAAC-ROS/isaac_ros_cumotion)
For business inquiries of this python library, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)
## Overview ## Overview
@@ -44,9 +48,9 @@ If you found this work useful, please cite the below report,
``` ```
@misc{curobo_report23, @misc{curobo_report23,
title={cuRobo: Parallelized Collision-Free Minimum-Jerk Robot Motion Generation}, title={cuRobo: Parallelized Collision-Free Minimum-Jerk Robot Motion Generation},
author={Balakumar Sundaralingam and Siva Kumar Sastry Hari and Adam Fishman and Caelan Garrett author={Balakumar Sundaralingam and Siva Kumar Sastry Hari and Adam Fishman and Caelan Garrett
and Karl Van Wyk and Valts Blukis and Alexander Millane and Helen Oleynikova and Ankur Handa and Karl Van Wyk and Valts Blukis and Alexander Millane and Helen Oleynikova and Ankur Handa
and Fabio Ramos and Nathan Ratliff and Dieter Fox}, and Fabio Ramos and Nathan Ratliff and Dieter Fox},
year={2023}, year={2023},
eprint={2310.17274}, eprint={2310.17274},

View File

@@ -85,10 +85,13 @@ class CuroboTorch(torch.nn.Module):
kin_state.ee_quaternion, kin_state.ee_quaternion,
] ]
if x_des is not None: if x_des is not None:
pose_distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose) pose_distance = self._robot_world.pose_distance(
x_des, kin_state.ee_pose, resize=True
).view(-1, 1)
features.append(pose_distance) features.append(pose_distance)
features.append(x_des.position) features.append(x_des.position)
features.append(x_des.quaternion) features.append(x_des.quaternion)
features = torch.cat(features, dim=-1) features = torch.cat(features, dim=-1)
return features return features
@@ -114,17 +117,25 @@ class CuroboTorch(torch.nn.Module):
def loss(self, x_des: Pose, q: torch.Tensor, q_in: torch.Tensor): def loss(self, x_des: Pose, q: torch.Tensor, q_in: torch.Tensor):
kin_state = self._robot_world.get_kinematics(q) kin_state = self._robot_world.get_kinematics(q)
distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose) distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose, resize=True)
d_sdf = self._robot_world.collision_constraint(kin_state.link_spheres_tensor.unsqueeze(1)) d_sdf = self._robot_world.collision_constraint(
d_self = self._robot_world.self_collision_cost(kin_state.link_spheres_tensor.unsqueeze(1)) kin_state.link_spheres_tensor.unsqueeze(1)
).view(-1)
d_self = self._robot_world.self_collision_cost(
kin_state.link_spheres_tensor.unsqueeze(1)
).view(-1)
loss = 0.1 * torch.linalg.norm(q_in - q, dim=-1) + distance + 100.0 * (d_self + d_sdf) loss = 0.1 * torch.linalg.norm(q_in - q, dim=-1) + distance + 100.0 * (d_self + d_sdf)
return loss return loss
def val_loss(self, x_des: Pose, q: torch.Tensor, q_in: torch.Tensor): def val_loss(self, x_des: Pose, q: torch.Tensor, q_in: torch.Tensor):
kin_state = self._robot_world.get_kinematics(q) kin_state = self._robot_world.get_kinematics(q)
distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose) distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose, resize=True)
d_sdf = self._robot_world.collision_constraint(kin_state.link_spheres_tensor.unsqueeze(1)) d_sdf = self._robot_world.collision_constraint(
d_self = self._robot_world.self_collision_cost(kin_state.link_spheres_tensor.unsqueeze(1)) kin_state.link_spheres_tensor.unsqueeze(1)
).view(-1)
d_self = self._robot_world.self_collision_cost(
kin_state.link_spheres_tensor.unsqueeze(1)
).view(-1)
loss = 10.0 * (d_self + d_sdf) + distance loss = 10.0 * (d_self + d_sdf) + distance
return loss return loss

View File

@@ -50,7 +50,7 @@ install_requires =
torch>=1.10 torch>=1.10
trimesh trimesh
yourdfpy>=0.0.53 yourdfpy>=0.0.53
warp-lang>=0.11.0 warp-lang>=0.9.0
scipy>=1.7.0 scipy>=1.7.0
tqdm tqdm
wheel wheel

View File

@@ -15,499 +15,16 @@ import warp as wp
wp.set_module_options({"fast_math": False}) wp.set_module_options({"fast_math": False})
# CuRobo
from curobo.util.warp import warp_support_sdf_struct
@wp.func # Check version of warp and import the supported SDF function.
def mesh_query_point_fn( if warp_support_sdf_struct():
idx: wp.uint64, # Local Folder
point: wp.vec3, from .warp_sdf_fns import get_closest_pt_batch_env, get_swept_closest_pt_batch_env
max_distance: float, else:
): # Local Folder
collide_result = wp.mesh_query_point(idx, point, max_distance) from .warp_sdf_fns_deprecated import get_closest_pt_batch_env, get_swept_closest_pt_batch_env
return collide_result
@wp.kernel
def get_swept_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
speed_dt: wp.array(dtype=wp.float32),
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
sweep_steps: wp.uint8,
enable_speed_metric: wp.uint8,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = int(0)
tid = wp.tid()
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
# read horizon
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
env_idx = int(0)
n_mesh = int(0)
# $wp.printf("%d, %d, %d, %d \n", tid, b_idx, h_idx, sph_idx)
# read sphere
sphere_0_distance = float(0.0)
sphere_2_distance = float(0.0)
sphere_0 = wp.vec3(0.0)
sphere_2 = wp.vec3(0.0)
sphere_int = wp.vec3(0.0)
sphere_temp = wp.vec3(0.0)
grad_vec = wp.vec3(0.0)
eta = float(0.0)
dt = float(0.0)
k0 = float(0.0)
sign = float(0.0)
dist = float(0.0)
dist_metric = float(0.0)
euclidean_distance = float(0.0)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
in_rad = in_sphere[3]
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
dt = speed_dt[0]
eta = activation_distance[0]
in_rad += eta
max_dist_buffer = float(0.0)
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
# read in sphere 0:
if h_idx > 0:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx - 1) * nspheres) + sph_idx]
sphere_0 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_0_distance = wp.length(sphere_0 - in_pt) / 2.0
if h_idx < horizon - 1:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx + 1) * nspheres) + sph_idx]
sphere_2 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_2_distance = wp.length(sphere_2 - in_pt) / 2.0
# read in sphere 2:
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
i = int(0)
dis_length = float(0.0)
jump_distance = float(0.0)
mid_distance = float(0.0)
if use_batch_env:
env_idx = env_query_idx[b_idx]
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
obj_w_pose_t = wp.transform_inverse(obj_w_pose)
local_pt = wp.transform_point(obj_w_pose, in_pt)
collide_result = mesh_query_point_fn(mesh[i], local_pt, max_dist_buffer)
if collide_result.result:
sign = collide_result.sign
cl_pt = wp.mesh_eval_position(
mesh[i], collide_result.face, collide_result.u, collide_result.v
)
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
else:
dist = -1.0 * dist
euclidean_distance = dist
else:
dist = max_dist_buffer
euclidean_distance = dist
dist = max(euclidean_distance - in_rad, in_rad)
mid_distance = euclidean_distance
# transform sphere -1
if h_idx > 0 and mid_distance < sphere_0_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_0)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_0_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + ((1.0 - k0) * sphere_temp)
collide_result = mesh_query_point_fn(mesh[i], sphere_int, max_dist_buffer)
if collide_result.result:
sign = collide_result.sign
cl_pt = wp.mesh_eval_position(
mesh[i], collide_result.face, collide_result.u, collide_result.v
)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += euclidean_distance
else:
dist = max(-dist - in_rad, in_rad)
euclidean_distance = dist
jump_distance += euclidean_distance
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_0_distance:
j = int(sweep_steps)
# transform sphere -1
if h_idx < horizon - 1 and mid_distance < sphere_2_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_2)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_2_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + (1.0 - k0) * sphere_temp
collide_result = mesh_query_point_fn(mesh[i], sphere_int, max_dist_buffer)
if collide_result.result:
sign = collide_result.sign
cl_pt = wp.mesh_eval_position(
mesh[i], collide_result.face, collide_result.u, collide_result.v
)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
euclidean_distance = dist
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
# cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += dist
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_2_distance:
j = int(sweep_steps)
i += 1
# return
if closest_distance <= 0.0:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
if enable_speed_metric == 1 and (h_idx > 0 and h_idx < horizon - 1):
# calculate sphere velocity and acceleration:
norm_vel_vec = wp.vec3(0.0)
sph_acc_vec = wp.vec3(0.0)
sph_vel = wp.float(0.0)
# use central difference
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
sph_vel = wp.length(norm_vel_vec)
if sph_vel > 1e-3:
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
norm_vel_vec = norm_vel_vec * (1.0 / sph_vel)
curvature_vec = sph_acc_vec / (sph_vel * sph_vel)
orth_proj = wp.mat33(0.0)
for i in range(3):
for j in range(3):
orth_proj[i, j] = -1.0 * norm_vel_vec[i] * norm_vel_vec[j]
orth_proj[0, 0] = orth_proj[0, 0] + 1.0
orth_proj[1, 1] = orth_proj[1, 1] + 1.0
orth_proj[2, 2] = orth_proj[2, 2] + 1.0
orth_curv = wp.vec3(
0.0, 0.0, 0.0
) # closest_distance * (orth_proj @ curvature_vec) #wp.matmul(orth_proj, curvature_vec)
orth_pt = wp.vec3(0.0, 0.0, 0.0) # orth_proj @ closest_point
for i in range(3):
orth_pt[i] = (
orth_proj[i, 0] * closest_point[0]
+ orth_proj[i, 1] * closest_point[1]
+ orth_proj[i, 2] * closest_point[2]
)
orth_curv[i] = closest_distance * (
orth_proj[i, 0] * curvature_vec[0]
+ orth_proj[i, 1] * curvature_vec[1]
+ orth_proj[i, 2] * curvature_vec[2]
)
closest_point = sph_vel * (orth_pt - orth_curv)
closest_distance = sph_vel * closest_distance
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@wp.kernel
def get_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
compute_esdf: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = wp.tid()
n_mesh = int(0)
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
env_idx = int(0)
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
sign = float(0.0)
dist = float(0.0)
grad_vec = wp.vec3(0.0)
eta = float(0.0)
dist_metric = float(0.0)
max_dist_buffer = float(0.0)
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[tid]
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
in_rad = in_sphere[3]
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
eta = activation_distance[0]
if compute_esdf != 1:
in_rad += eta
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
# TODO: read vec4 and use first 3 for sphere position and last one for radius
# in_pt = pt[tid]
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
dis_length = float(0.0)
# read env index:
if use_batch_env:
env_idx = env_query_idx[b_idx]
# read number of boxes in current environment:
# get start index
i = int(0)
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
max_dist_value = -1.0 * max_dist_buffer
if compute_esdf == 1:
closest_distance = max_dist_value
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
local_pt = wp.transform_point(obj_w_pose, in_pt)
collide_result = mesh_query_point_fn(mesh[i], local_pt, max_dist_buffer)
if collide_result.result:
sign = collide_result.sign
cl_pt = wp.mesh_eval_position(
mesh[i], collide_result.face, collide_result.u, collide_result.v
)
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = -1.0 * dis_length * sign
if compute_esdf == 1:
if dist > max_dist_value:
max_dist_value = dist
closest_distance = dist
if write_grad == 1:
cl_pt = sign * (delta) / dis_length
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_point = grad_vec
else:
dist = dist + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
i += 1
if closest_distance == 0 and compute_esdf != 1:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
else:
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
class SdfMeshWarpPy(torch.autograd.Function): class SdfMeshWarpPy(torch.autograd.Function):

View File

@@ -0,0 +1,506 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Third Party
import warp as wp
@wp.func
def mesh_query_point_fn(
idx: wp.uint64,
point: wp.vec3,
max_distance: float,
):
collide_result = wp.mesh_query_point(idx, point, max_distance)
return collide_result
@wp.kernel
def get_swept_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
speed_dt: wp.array(dtype=wp.float32),
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
sweep_steps: wp.uint8,
enable_speed_metric: wp.uint8,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = int(0)
tid = wp.tid()
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
# read horizon
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
env_idx = int(0)
n_mesh = int(0)
# $wp.printf("%d, %d, %d, %d \n", tid, b_idx, h_idx, sph_idx)
# read sphere
sphere_0_distance = float(0.0)
sphere_2_distance = float(0.0)
sphere_0 = wp.vec3(0.0)
sphere_2 = wp.vec3(0.0)
sphere_int = wp.vec3(0.0)
sphere_temp = wp.vec3(0.0)
grad_vec = wp.vec3(0.0)
eta = float(0.0)
dt = float(0.0)
k0 = float(0.0)
sign = float(0.0)
dist = float(0.0)
dist_metric = float(0.0)
euclidean_distance = float(0.0)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
in_rad = in_sphere[3]
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
dt = speed_dt[0]
eta = activation_distance[0]
in_rad += eta
max_dist_buffer = float(0.0)
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
# read in sphere 0:
if h_idx > 0:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx - 1) * nspheres) + sph_idx]
sphere_0 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_0_distance = wp.length(sphere_0 - in_pt) / 2.0
if h_idx < horizon - 1:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx + 1) * nspheres) + sph_idx]
sphere_2 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_2_distance = wp.length(sphere_2 - in_pt) / 2.0
# read in sphere 2:
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
i = int(0)
dis_length = float(0.0)
jump_distance = float(0.0)
mid_distance = float(0.0)
if use_batch_env:
env_idx = env_query_idx[b_idx]
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
obj_w_pose_t = wp.transform_inverse(obj_w_pose)
local_pt = wp.transform_point(obj_w_pose, in_pt)
collide_result = mesh_query_point_fn(mesh[i], local_pt, max_dist_buffer)
if collide_result.result:
sign = collide_result.sign
cl_pt = wp.mesh_eval_position(
mesh[i], collide_result.face, collide_result.u, collide_result.v
)
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
else:
dist = -1.0 * dist
euclidean_distance = dist
else:
dist = max_dist_buffer
euclidean_distance = dist
dist = max(euclidean_distance - in_rad, in_rad)
mid_distance = euclidean_distance
# transform sphere -1
if h_idx > 0 and mid_distance < sphere_0_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_0)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_0_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + ((1.0 - k0) * sphere_temp)
collide_result = mesh_query_point_fn(mesh[i], sphere_int, max_dist_buffer)
if collide_result.result:
sign = collide_result.sign
cl_pt = wp.mesh_eval_position(
mesh[i], collide_result.face, collide_result.u, collide_result.v
)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += euclidean_distance
else:
dist = max(-dist - in_rad, in_rad)
euclidean_distance = dist
jump_distance += euclidean_distance
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_0_distance:
j = int(sweep_steps)
# transform sphere -1
if h_idx < horizon - 1 and mid_distance < sphere_2_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_2)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_2_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + (1.0 - k0) * sphere_temp
collide_result = mesh_query_point_fn(mesh[i], sphere_int, max_dist_buffer)
if collide_result.result:
sign = collide_result.sign
cl_pt = wp.mesh_eval_position(
mesh[i], collide_result.face, collide_result.u, collide_result.v
)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
euclidean_distance = dist
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
# cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += dist
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_2_distance:
j = int(sweep_steps)
i += 1
# return
if closest_distance <= 0.0:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
if enable_speed_metric == 1 and (h_idx > 0 and h_idx < horizon - 1):
# calculate sphere velocity and acceleration:
norm_vel_vec = wp.vec3(0.0)
sph_acc_vec = wp.vec3(0.0)
sph_vel = wp.float(0.0)
# use central difference
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
sph_vel = wp.length(norm_vel_vec)
if sph_vel > 1e-3:
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
norm_vel_vec = norm_vel_vec * (1.0 / sph_vel)
curvature_vec = sph_acc_vec / (sph_vel * sph_vel)
orth_proj = wp.mat33(0.0)
for i in range(3):
for j in range(3):
orth_proj[i, j] = -1.0 * norm_vel_vec[i] * norm_vel_vec[j]
orth_proj[0, 0] = orth_proj[0, 0] + 1.0
orth_proj[1, 1] = orth_proj[1, 1] + 1.0
orth_proj[2, 2] = orth_proj[2, 2] + 1.0
orth_curv = wp.vec3(
0.0, 0.0, 0.0
) # closest_distance * (orth_proj @ curvature_vec) #wp.matmul(orth_proj, curvature_vec)
orth_pt = wp.vec3(0.0, 0.0, 0.0) # orth_proj @ closest_point
for i in range(3):
orth_pt[i] = (
orth_proj[i, 0] * closest_point[0]
+ orth_proj[i, 1] * closest_point[1]
+ orth_proj[i, 2] * closest_point[2]
)
orth_curv[i] = closest_distance * (
orth_proj[i, 0] * curvature_vec[0]
+ orth_proj[i, 1] * curvature_vec[1]
+ orth_proj[i, 2] * curvature_vec[2]
)
closest_point = sph_vel * (orth_pt - orth_curv)
closest_distance = sph_vel * closest_distance
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@wp.kernel
def get_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
compute_esdf: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = wp.tid()
n_mesh = int(0)
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
env_idx = int(0)
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
sign = float(0.0)
dist = float(0.0)
grad_vec = wp.vec3(0.0)
eta = float(0.0)
dist_metric = float(0.0)
max_dist_buffer = float(0.0)
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[tid]
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
in_rad = in_sphere[3]
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
eta = activation_distance[0]
if compute_esdf != 1:
in_rad += eta
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
# TODO: read vec4 and use first 3 for sphere position and last one for radius
# in_pt = pt[tid]
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
dis_length = float(0.0)
# read env index:
if use_batch_env:
env_idx = env_query_idx[b_idx]
# read number of boxes in current environment:
# get start index
i = int(0)
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
max_dist_value = -1.0 * max_dist_buffer
if compute_esdf == 1:
closest_distance = max_dist_value
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
local_pt = wp.transform_point(obj_w_pose, in_pt)
collide_result = mesh_query_point_fn(mesh[i], local_pt, max_dist_buffer)
if collide_result.result:
sign = collide_result.sign
cl_pt = wp.mesh_eval_position(
mesh[i], collide_result.face, collide_result.u, collide_result.v
)
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = -1.0 * dis_length * sign
if compute_esdf == 1:
if dist > max_dist_value:
max_dist_value = dist
closest_distance = dist
if write_grad == 1:
cl_pt = sign * (delta) / dis_length
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_point = grad_vec
else:
dist = dist + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
i += 1
if closest_distance == 0 and compute_esdf != 1:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
else:
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]

View File

@@ -0,0 +1,496 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Third Party
import warp as wp
@wp.kernel
def get_swept_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
speed_dt: wp.array(dtype=wp.float32),
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
sweep_steps: wp.uint8,
enable_speed_metric: wp.uint8,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = int(0)
tid = wp.tid()
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
# read horizon
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
env_idx = int(0)
n_mesh = int(0)
# read sphere
sphere_0_distance = float(0.0)
sphere_2_distance = float(0.0)
sphere_0 = wp.vec3(0.0)
sphere_2 = wp.vec3(0.0)
sphere_int = wp.vec3(0.0)
sphere_temp = wp.vec3(0.0)
grad_vec = wp.vec3(0.0)
eta = float(0.0)
dt = float(0.0)
k0 = float(0.0)
face_index = int(0)
face_u = float(0.0)
face_v = float(0.0)
sign = float(0.0)
dist = float(0.0)
dist_metric = float(0.0)
euclidean_distance = float(0.0)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
in_rad = in_sphere[3]
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
dt = speed_dt[0]
eta = activation_distance[0]
in_rad += eta
max_dist_buffer = float(0.0)
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
# read in sphere 0:
if h_idx > 0:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx - 1) * nspheres) + sph_idx]
sphere_0 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_0_distance = wp.length(sphere_0 - in_pt) / 2.0
if h_idx < horizon - 1:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx + 1) * nspheres) + sph_idx]
sphere_2 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_2_distance = wp.length(sphere_2 - in_pt) / 2.0
# read in sphere 2:
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
i = int(0)
dis_length = float(0.0)
jump_distance = float(0.0)
mid_distance = float(0.0)
if use_batch_env:
env_idx = env_query_idx[b_idx]
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
obj_w_pose_t = wp.transform_inverse(obj_w_pose)
local_pt = wp.transform_point(obj_w_pose, in_pt)
if wp.mesh_query_point(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
else:
dist = -1.0 * dist
euclidean_distance = dist
else:
dist = max_dist_buffer
euclidean_distance = dist
dist = max(euclidean_distance - in_rad, in_rad)
mid_distance = euclidean_distance
# transform sphere -1
if h_idx > 0 and mid_distance < sphere_0_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_0)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_0_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + ((1.0 - k0) * sphere_temp)
if wp.mesh_query_point(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += euclidean_distance
else:
dist = max(-dist - in_rad, in_rad)
euclidean_distance = dist
jump_distance += euclidean_distance
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_0_distance:
j = int(sweep_steps)
# transform sphere -1
if h_idx < horizon - 1 and mid_distance < sphere_2_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_2)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_2_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + (1.0 - k0) * sphere_temp
if wp.mesh_query_point(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
euclidean_distance = dist
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
# cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += dist
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_2_distance:
j = int(sweep_steps)
i += 1
# return
if closest_distance <= 0.0:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
if enable_speed_metric == 1 and (h_idx > 0 and h_idx < horizon - 1):
# calculate sphere velocity and acceleration:
norm_vel_vec = wp.vec3(0.0)
sph_acc_vec = wp.vec3(0.0)
sph_vel = wp.float(0.0)
# use central difference
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
sph_vel = wp.length(norm_vel_vec)
if sph_vel > 1e-3:
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
norm_vel_vec = norm_vel_vec * (1.0 / sph_vel)
curvature_vec = sph_acc_vec / (sph_vel * sph_vel)
orth_proj = wp.mat33(0.0)
for i in range(3):
for j in range(3):
orth_proj[i, j] = -1.0 * norm_vel_vec[i] * norm_vel_vec[j]
orth_proj[0, 0] = orth_proj[0, 0] + 1.0
orth_proj[1, 1] = orth_proj[1, 1] + 1.0
orth_proj[2, 2] = orth_proj[2, 2] + 1.0
orth_curv = wp.vec3(
0.0, 0.0, 0.0
) # closest_distance * (orth_proj @ curvature_vec) #wp.matmul(orth_proj, curvature_vec)
orth_pt = wp.vec3(0.0, 0.0, 0.0) # orth_proj @ closest_point
for i in range(3):
orth_pt[i] = (
orth_proj[i, 0] * closest_point[0]
+ orth_proj[i, 1] * closest_point[1]
+ orth_proj[i, 2] * closest_point[2]
)
orth_curv[i] = closest_distance * (
orth_proj[i, 0] * curvature_vec[0]
+ orth_proj[i, 1] * curvature_vec[1]
+ orth_proj[i, 2] * curvature_vec[2]
)
closest_point = sph_vel * (orth_pt - orth_curv)
closest_distance = sph_vel * closest_distance
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@wp.kernel
def get_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
compute_esdf: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = wp.tid()
n_mesh = int(0)
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
env_idx = int(0)
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
face_index = int(0)
face_u = float(0.0)
face_v = float(0.0)
sign = float(0.0)
dist = float(0.0)
grad_vec = wp.vec3(0.0)
eta = float(0.0)
dist_metric = float(0.0)
max_dist_buffer = float(0.0)
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[tid]
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
in_rad = in_sphere[3]
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
eta = activation_distance[0]
if compute_esdf != 1:
in_rad += eta
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
# TODO: read vec4 and use first 3 for sphere position and last one for radius
# in_pt = pt[tid]
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
dis_length = float(0.0)
# read env index:
if use_batch_env:
env_idx = env_query_idx[b_idx]
# read number of boxes in current environment:
# get start index
i = int(0)
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
max_dist_value = -1.0 * max_dist_buffer
if compute_esdf == 1:
closest_distance = max_dist_value
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
local_pt = wp.transform_point(obj_w_pose, in_pt)
if wp.mesh_query_point(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = -1.0 * dis_length * sign
if compute_esdf == 1:
if dist > max_dist_value:
max_dist_value = dist
closest_distance = dist
if write_grad == 1:
cl_pt = sign * (delta) / dis_length
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_point = grad_vec
else:
dist = dist + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
i += 1
if closest_distance == 0 and compute_esdf != 1:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
else:
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]

View File

@@ -1550,7 +1550,10 @@ def make_bound_pos_smooth_kernel(dof_template: int):
out_grad_a[b_addrs + i] = g_a[i] out_grad_a[b_addrs + i] = g_a[i]
out_grad_j[b_addrs + i] = g_j[i] out_grad_j[b_addrs + i] = g_j[i]
return wp.Kernel(forward_bound_smooth_loop_warp) module = wp.get_module(forward_bound_smooth_loop_warp.__module__)
key = "forward_bound_smooth_loop_warp_" + str(dof_template)
return wp.Kernel(forward_bound_smooth_loop_warp, key=key, module=module)
@get_cache_fn_decorator() @get_cache_fn_decorator()
@@ -1647,4 +1650,6 @@ def make_bound_pos_kernel(dof_template: int):
for i in range(dof_template): for i in range(dof_template):
out_grad_p[b_addrs + i] = g_p[i] out_grad_p[b_addrs + i] = g_p[i]
return wp.Kernel(forward_bound_pos_loop_warp) module = wp.get_module(forward_bound_pos_loop_warp.__module__)
key = "forward_bound_pos_loop_warp_" + str(dof_template)
return wp.Kernel(forward_bound_pos_loop_warp, key=key, module=module)

View File

@@ -224,7 +224,9 @@ def make_l2_kernel(dof_template: int):
for i in range(dof_template): for i in range(dof_template):
out_grad_p[b_addrs + i] = g_p[i] out_grad_p[b_addrs + i] = g_p[i]
return wp.Kernel(forward_l2_loop_warp) module = wp.get_module(forward_l2_loop_warp.__module__)
key = "forward_l2_loop" + str(dof_template)
return wp.Kernel(forward_l2_loop_warp, key=key, module=module)
# create a bound cost tensor: # create a bound cost tensor:

View File

@@ -466,6 +466,8 @@ class PoseCost(CostBase, PoseCostConfig):
batch_pose_idx: torch.Tensor, batch_pose_idx: torch.Tensor,
mode: PoseErrorType = PoseErrorType.BATCH_GOAL, mode: PoseErrorType = PoseErrorType.BATCH_GOAL,
): ):
if len(query_pose.position.shape) == 2:
log_error("Query pose should be [batch, horizon, -1]")
ee_goal_pos = goal_pose.position ee_goal_pos = goal_pose.position
ee_goal_quat = goal_pose.quaternion ee_goal_quat = goal_pose.quaternion
self.cost_type = mode self.cost_type = mode
@@ -476,9 +478,9 @@ class PoseCost(CostBase, PoseCostConfig):
num_goals = 1 num_goals = 1
distance = PoseError.apply( distance = PoseError.apply(
query_pose.position.unsqueeze(1), query_pose.position,
ee_goal_pos, ee_goal_pos,
query_pose.quaternion.unsqueeze(1), query_pose.quaternion,
ee_goal_quat, ee_goal_quat,
self.vec_weight, self.vec_weight,
self.weight, self.weight,

View File

@@ -9,11 +9,14 @@
# its affiliates is strictly prohibited. # its affiliates is strictly prohibited.
# #
# Third Party # Third Party
import warp as wp import warp as wp
from packaging import version
# CuRobo # CuRobo
from curobo.types.base import TensorDeviceType from curobo.types.base import TensorDeviceType
from curobo.util.logger import log_info
def init_warp(quiet=True, tensor_args: TensorDeviceType = TensorDeviceType()): def init_warp(quiet=True, tensor_args: TensorDeviceType = TensorDeviceType()):
@@ -26,3 +29,19 @@ def init_warp(quiet=True, tensor_args: TensorDeviceType = TensorDeviceType()):
# wp.force_load(wp.device_from_torch(tensor_args.device)) # wp.force_load(wp.device_from_torch(tensor_args.device))
return True return True
def warp_support_sdf_struct(wp_module=None):
if wp_module is None:
wp_module = wp
wp_version = wp_module.config.version
if version.parse(wp_version) < version.parse("1.0.0"):
log_info(
"Warp version is "
+ wp_version
+ " < 1.0.0, using older sdf kernels."
+ "No issues expected."
)
return False
return True

View File

@@ -190,3 +190,11 @@ def get_multi_arm_robot_list() -> List[str]:
"quad_ur10e.yml", "quad_ur10e.yml",
] ]
return robot_list return robot_list
def merge_dict_a_into_b(a, b):
for k, v in a.items():
if isinstance(v, dict):
merge_dict_a_into_b(v, b[k])
else:
b[k] = v

View File

@@ -343,9 +343,11 @@ class RobotWorld(RobotWorldConfig):
d_mask = mask(d_self, d_world, d_bound) d_mask = mask(d_self, d_world, d_bound)
return d_mask return d_mask
def pose_distance(self, x_des: Pose, x_current: Pose): def pose_distance(self, x_des: Pose, x_current: Pose, resize: bool = False):
unsqueeze = False
if len(x_current.position.shape) == 2: if len(x_current.position.shape) == 2:
x_current = x_current.unsqueeze(1) x_current = x_current.unsqueeze(1)
unsqueeze = True
# calculate pose loss: # calculate pose loss:
if ( if (
self._batch_pose_idx is None self._batch_pose_idx is None
@@ -355,6 +357,8 @@ class RobotWorld(RobotWorldConfig):
0, x_current.position.shape[0], 1, device=self.tensor_args.device, dtype=torch.int32 0, x_current.position.shape[0], 1, device=self.tensor_args.device, dtype=torch.int32
) )
distance = self.pose_cost.forward_pose(x_des, x_current, self._batch_pose_idx) distance = self.pose_cost.forward_pose(x_des, x_current, self._batch_pose_idx)
if unsqueeze and resize:
distance = distance.squeeze(1)
return distance return distance
def get_point_robot_distance(self, points: torch.Tensor, q: torch.Tensor): def get_point_robot_distance(self, points: torch.Tensor, q: torch.Tensor):

View File

@@ -61,6 +61,7 @@ from curobo.util_file import (
get_world_configs_path, get_world_configs_path,
join_path, join_path,
load_yaml, load_yaml,
merge_dict_a_into_b,
) )
from curobo.wrap.reacher.types import ReacherSolveState, ReacherSolveType from curobo.wrap.reacher.types import ReacherSolveState, ReacherSolveType
from curobo.wrap.wrap_base import WrapResult from curobo.wrap.wrap_base import WrapResult
@@ -107,6 +108,8 @@ class MpcSolverConfig:
step_dt: Optional[float] = None, step_dt: Optional[float] = None,
use_lbfgs: bool = False, use_lbfgs: bool = False,
use_mppi: bool = True, use_mppi: bool = True,
particle_file: str = "particle_mpc.yml",
override_particle_file: str = None,
): ):
"""Create an MPC solver configuration from robot and world configuration. """Create an MPC solver configuration from robot and world configuration.
@@ -151,6 +154,9 @@ class MpcSolverConfig:
time for a single step. time for a single step.
use_lbfgs: Use L-BFGS solver for MPC. Highly experimental. use_lbfgs: Use L-BFGS solver for MPC. Highly experimental.
use_mppi: Use MPPI solver for MPC. use_mppi: Use MPPI solver for MPC.
particle_file: Particle based MPC config file.
override_particle_file: Optional config file for overriding the parameters in the
particle based MPC config file.
Returns: Returns:
MpcSolverConfig: Configuration for the MPC solver. MpcSolverConfig: Configuration for the MPC solver.
@@ -159,8 +165,9 @@ class MpcSolverConfig:
if use_cuda_graph_full_step: if use_cuda_graph_full_step:
log_error("use_cuda_graph_full_step currently is not supported") log_error("use_cuda_graph_full_step currently is not supported")
task_file = "particle_mpc.yml" config_data = load_yaml(join_path(get_task_configs_path(), particle_file))
config_data = load_yaml(join_path(get_task_configs_path(), task_file)) if override_particle_file is not None:
merge_dict_a_into_b(load_yaml(override_particle_file), config_data)
config_data["mppi"]["n_problems"] = 1 config_data["mppi"]["n_problems"] = 1
if step_dt is not None: if step_dt is not None:
config_data["model"]["dt_traj_params"]["base_dt"] = step_dt config_data["model"]["dt_traj_params"]["base_dt"] = step_dt