Significantly improved convergence for mesh and cuboid, new ESDF collision.

This commit is contained in:
Balakumar Sundaralingam
2024-03-18 11:19:48 -07:00
parent 286b3820a5
commit b1f63e8778
100 changed files with 7587 additions and 2589 deletions

View File

@@ -12,8 +12,11 @@
# Third Party
import torch
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
@torch.jit.script
@get_torch_jit_decorator()
def project_depth_to_pointcloud(depth_image: torch.Tensor, intrinsics_matrix: torch.Tensor):
"""Projects numpy depth image to point cloud.
@@ -43,7 +46,7 @@ def project_depth_to_pointcloud(depth_image: torch.Tensor, intrinsics_matrix: to
return raw_pc
@torch.jit.script
@get_torch_jit_decorator()
def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor):
"""Projects numpy depth image to point cloud.
@@ -54,10 +57,10 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
Returns:
array of float (h, w, 3)
"""
fx = intrinsics_matrix[:, 0, 0]
fy = intrinsics_matrix[:, 1, 1]
cx = intrinsics_matrix[:, 0, 2]
cy = intrinsics_matrix[:, 1, 2]
fx = intrinsics_matrix[:, 0:1, 0:1]
fy = intrinsics_matrix[:, 1:2, 1:2]
cx = intrinsics_matrix[:, 0:1, 2:3]
cy = intrinsics_matrix[:, 1:2, 2:3]
input_x = torch.arange(width, dtype=torch.float32, device=intrinsics_matrix.device)
input_y = torch.arange(height, dtype=torch.float32, device=intrinsics_matrix.device)
@@ -73,7 +76,6 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
device=intrinsics_matrix.device,
dtype=torch.float32,
)
output_x = (input_x - cx) / fx
output_y = (input_y - cy) / fy
@@ -84,7 +86,7 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
return rays
@torch.jit.script
@get_torch_jit_decorator()
def project_pointcloud_to_depth(
pointcloud: torch.Tensor,
output_image: torch.Tensor,
@@ -106,7 +108,7 @@ def project_pointcloud_to_depth(
return output_image
@torch.jit.script
@get_torch_jit_decorator()
def project_depth_using_rays(
depth_image: torch.Tensor, rays: torch.Tensor, filter_origin: bool = False
):

View File

@@ -11,8 +11,10 @@
# Third Party
import torch
# from curobo.util.torch_utils import get_torch_jit_decorator
# @torch.jit.script
# @get_torch_jit_decorator()
def lookup_distance(pt, dist_matrix_flat, num_voxels):
# flatten:
ind_pt = (
@@ -22,7 +24,7 @@ def lookup_distance(pt, dist_matrix_flat, num_voxels):
return dist
# @torch.jit.script
# @get_torch_jit_decorator()
def compute_sdf_gradient(pt, dist_matrix_flat, num_voxels, dist):
grad_l = []
for i in range(3): # x,y,z

View File

@@ -30,6 +30,11 @@ def create_collision_checker(config: WorldCollisionConfig):
from curobo.geom.sdf.world_mesh import WorldMeshCollision
return WorldMeshCollision(config)
elif config.checker_type == CollisionCheckerType.VOXEL:
# CuRobo
from curobo.geom.sdf.world_voxel import WorldVoxelCollision
return WorldVoxelCollision(config)
else:
log_error("Not implemented", exc_info=True)
log_error("Unknown Collision Checker type: " + config.checker_type, exc_info=True)
raise NotImplementedError

View File

@@ -16,284 +16,6 @@ import warp as wp
wp.set_module_options({"fast_math": False})
# create warp kernels:
@wp.kernel
def get_swept_closest_pt(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
speed_dt: wp.array(dtype=wp.float32),
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.float32,
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
sweep_steps: wp.uint8,
enable_speed_metric: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = int(0)
tid = wp.tid()
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
# read horizon
eta = float(0.0) # 5cm buffer
dt = float(1.0)
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
n_mesh = int(0)
# $wp.printf("%d, %d, %d, %d \n", tid, b_idx, h_idx, sph_idx)
# read sphere
sphere_0_distance = float(0.0)
sphere_2_distance = float(0.0)
sphere_0 = wp.vec3(0.0)
sphere_2 = wp.vec3(0.0)
sphere_int = wp.vec3(0.0)
sphere_temp = wp.vec3(0.0)
k0 = float(0.0)
face_index = int(0)
face_u = float(0.0)
face_v = float(0.0)
sign = float(0.0)
dist = float(0.0)
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
in_rad = in_sphere[3]
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
eta = activation_distance[0]
dt = speed_dt[0]
in_rad += eta
max_dist_buffer = float(0.0)
max_dist_buffer = max_dist
if in_rad > max_dist_buffer:
max_dist_buffer += in_rad
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
# read in sphere 0:
# read in sphere 0:
if h_idx > 0:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx - 1) * nspheres) + sph_idx]
sphere_0 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_0_distance = wp.length(sphere_0 - in_pt) / 2.0
if h_idx < horizon - 1:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx + 1) * nspheres) + sph_idx]
sphere_2 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_2_distance = wp.length(sphere_2 - in_pt) / 2.0
# read in sphere 2:
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
i = int(0)
dis_length = float(0.0)
jump_distance = float(0.0)
mid_distance = float(0.0)
n_mesh = n_env_mesh[0]
obj_position = wp.vec3()
while i < n_mesh:
if mesh_enable[i] == uint_one:
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
obj_w_pose_t = wp.transform_inverse(obj_w_pose)
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
local_pt = wp.transform_point(obj_w_pose, in_pt)
if wp.mesh_query_point(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - local_pt)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - local_pt) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
else:
dist = -1.0 * dist
else:
dist = in_rad
dist = max(dist - in_rad, in_rad)
mid_distance = dist
# transform sphere -1
if h_idx > 0 and mid_distance < sphere_0_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_0)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_0_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + ((1.0 - k0) * sphere_temp)
if wp.mesh_query_point(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - sphere_int)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - sphere_int) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
dist = max(dist - in_rad, in_rad)
jump_distance += dist
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_0_distance:
j = int(sweep_steps)
# transform sphere -1
if h_idx < horizon - 1 and mid_distance < sphere_2_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_2)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_2_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + (1.0 - k0) * sphere_temp
if wp.mesh_query_point(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - sphere_int)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - sphere_int) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
dist = max(dist - in_rad, in_rad)
jump_distance += dist
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_2_distance:
j = int(sweep_steps)
i += 1
# return
if closest_distance == 0:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
if enable_speed_metric == 1 and (h_idx > 0 and h_idx < horizon - 1):
# calculate sphere velocity and acceleration:
norm_vel_vec = wp.vec3(0.0)
sph_acc_vec = wp.vec3(0.0)
sph_vel = wp.float(0.0)
# use central difference
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
# norm_vel_vec = -1.0 * norm_vel_vec
# sph_acc_vec = -1.0 * sph_acc_vec
sph_vel = wp.length(norm_vel_vec)
norm_vel_vec = norm_vel_vec / sph_vel
orth_proj = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) - wp.outer(
norm_vel_vec, norm_vel_vec
)
curvature_vec = orth_proj * (sph_acc_vec / (sph_vel * sph_vel))
closest_point = sph_vel * ((orth_proj * closest_point) - closest_distance * curvature_vec)
closest_distance = sph_vel * closest_distance
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
if closest_distance > 0.0:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@wp.kernel
def get_swept_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
@@ -307,7 +29,7 @@ def get_swept_closest_pt_batch_env(
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.float32,
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
@@ -316,6 +38,7 @@ def get_swept_closest_pt_batch_env(
sweep_steps: wp.uint8,
enable_speed_metric: wp.uint8,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
@@ -357,6 +80,7 @@ def get_swept_closest_pt_batch_env(
sign = float(0.0)
dist = float(0.0)
dist_metric = float(0.0)
euclidean_distance = float(0.0)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
@@ -374,7 +98,7 @@ def get_swept_closest_pt_batch_env(
eta = activation_distance[0]
in_rad += eta
max_dist_buffer = float(0.0)
max_dist_buffer = max_dist
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
@@ -396,7 +120,8 @@ def get_swept_closest_pt_batch_env(
dis_length = float(0.0)
jump_distance = float(0.0)
mid_distance = float(0.0)
env_idx = env_query_idx[b_idx]
if use_batch_env:
env_idx = env_query_idx[b_idx]
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
@@ -423,26 +148,33 @@ def get_swept_closest_pt_batch_env(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - local_pt)
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - local_pt) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
else:
dist = -1.0 * dist
euclidean_distance = dist
else:
dist = max_dist_buffer
dist = max(dist - in_rad, in_rad)
euclidean_distance = dist
dist = max(euclidean_distance - in_rad, in_rad)
mid_distance = dist
mid_distance = euclidean_distance
# transform sphere -1
if h_idx > 0 and mid_distance < sphere_0_distance:
jump_distance = mid_distance
@@ -457,24 +189,31 @@ def get_swept_closest_pt_batch_env(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - sphere_int)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - sphere_int) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(dist - in_rad, in_rad)
jump_distance += dist
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += euclidean_distance
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
euclidean_distance = dist
jump_distance += euclidean_distance
else:
jump_distance += max_dist_buffer
j += 1
@@ -495,24 +234,30 @@ def get_swept_closest_pt_batch_env(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - sphere_int)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - sphere_int) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
euclidean_distance = dist
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
# cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(dist - in_rad, in_rad)
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += dist
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
else:
jump_distance += max_dist_buffer
@@ -542,179 +287,54 @@ def get_swept_closest_pt_batch_env(
# use central difference
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
# norm_vel_vec = -1.0 * norm_vel_vec
# sph_acc_vec = -1.0 * sph_acc_vec
sph_vel = wp.length(norm_vel_vec)
if sph_vel > 1e-3:
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
norm_vel_vec = norm_vel_vec / sph_vel
norm_vel_vec = norm_vel_vec * (1.0 / sph_vel)
orth_proj = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) - wp.outer(
norm_vel_vec, norm_vel_vec
)
curvature_vec = sph_acc_vec / (sph_vel * sph_vel)
curvature_vec = orth_proj * (sph_acc_vec / (sph_vel * sph_vel))
orth_proj = wp.mat33(0.0)
for i in range(3):
for j in range(3):
orth_proj[i, j] = -1.0 * norm_vel_vec[i] * norm_vel_vec[j]
closest_point = sph_vel * ((orth_proj * closest_point) - closest_distance * curvature_vec)
orth_proj[0, 0] = orth_proj[0, 0] + 1.0
orth_proj[1, 1] = orth_proj[1, 1] + 1.0
orth_proj[2, 2] = orth_proj[2, 2] + 1.0
closest_distance = sph_vel * closest_distance
orth_curv = wp.vec3(
0.0, 0.0, 0.0
) # closest_distance * (orth_proj @ curvature_vec) #wp.matmul(orth_proj, curvature_vec)
orth_pt = wp.vec3(0.0, 0.0, 0.0) # orth_proj @ closest_point
for i in range(3):
orth_pt[i] = (
orth_proj[i, 0] * closest_point[0]
+ orth_proj[i, 1] * closest_point[1]
+ orth_proj[i, 2] * closest_point[2]
)
orth_curv[i] = closest_distance * (
orth_proj[i, 0] * curvature_vec[0]
+ orth_proj[i, 1] * curvature_vec[1]
+ orth_proj[i, 2] * curvature_vec[2]
)
closest_point = sph_vel * (orth_pt - orth_curv)
closest_distance = sph_vel * closest_distance
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
if closest_distance > 0.0:
closest_distance = weight[0]
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@wp.kernel
def get_closest_pt(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.float32,
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = wp.tid()
n_mesh = int(0)
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
# env_idx = int(0)
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
face_index = int(0)
face_u = float(0.0)
face_v = float(0.0)
sign = float(0.0)
dist = float(0.0)
grad_vec = wp.vec3(0.0)
eta = float(0.05)
dist_metric = float(0.0)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[tid]
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
in_rad = in_sphere[3]
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
eta = activation_distance[0]
in_rad += eta
max_dist_buffer = float(0.0)
max_dist_buffer = max_dist
if in_rad > max_dist_buffer:
max_dist_buffer += in_rad
# TODO: read vec4 and use first 3 for sphere position and last one for radius
# in_pt = pt[tid]
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
i = int(0)
dis_length = float(0.0)
# read env index:
# env_idx = env_query_idx[b_idx]
# read number of boxes in current environment:
# get start index
i = int(0)
n_mesh = n_env_mesh[0]
obj_position = wp.vec3()
# mesh_idx = wp.uint64(0)
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
# read object pose:
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
local_pt = wp.transform_point(obj_w_pose, in_pt)
# mesh_idx = mesh[i]
if wp.mesh_query_point(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - local_pt)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - local_pt) / dis_length
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
i += 1
if closest_distance == 0:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
else:
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
if closest_distance > 0.0:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@wp.kernel
def get_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
@@ -727,13 +347,15 @@ def get_closest_pt_batch_env(
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.float32,
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
compute_esdf: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
@@ -779,8 +401,9 @@ def get_closest_pt_batch_env(
return
eta = activation_distance[0]
in_rad += eta
max_dist_buffer = max_dist
if compute_esdf != 1:
in_rad += eta
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
@@ -791,7 +414,9 @@ def get_closest_pt_batch_env(
dis_length = float(0.0)
# read env index:
env_idx = env_query_idx[b_idx]
if use_batch_env:
env_idx = env_query_idx[b_idx]
# read number of boxes in current environment:
# get start index
@@ -799,7 +424,9 @@ def get_closest_pt_batch_env(
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
max_dist_value = -1.0 * max_dist_buffer
if compute_esdf == 1:
closest_distance = max_dist_value
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
@@ -822,21 +449,39 @@ def get_closest_pt_batch_env(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - local_pt)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - local_pt) / dis_length
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = -1.0 * dis_length * sign
if compute_esdf == 1:
if dist > max_dist_value:
max_dist_value = dist
closest_distance = dist
if write_grad == 1:
cl_pt = sign * (delta) / dis_length
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_point = grad_vec
else:
dist = dist + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
i += 1
if closest_distance == 0:
if closest_distance == 0 and compute_esdf != 1:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
@@ -850,8 +495,7 @@ def get_closest_pt_batch_env(
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
if closest_distance > 0.0:
closest_distance = weight[0]
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@@ -871,62 +515,43 @@ class SdfMeshWarpPy(torch.autograd.Function):
mesh_pose_inverse,
mesh_enable,
n_env_mesh,
max_dist=0.05,
max_dist,
env_query_idx=None,
return_loss=False,
compute_esdf=False,
):
b, h, n, _ = query_spheres.shape
use_batch_env = True
if env_query_idx is None:
# launch
wp.launch(
kernel=get_closest_pt,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
max_dist,
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
],
stream=wp.stream_from_torch(query_spheres.device),
)
else:
wp.launch(
kernel=get_closest_pt_batch_env,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
max_dist,
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
],
stream=wp.stream_from_torch(query_spheres.device),
)
use_batch_env = False
env_query_idx = n_env_mesh
wp.launch(
kernel=get_closest_pt_batch_env,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
wp.from_torch(max_dist, dtype=wp.float32),
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
use_batch_env,
compute_esdf,
],
stream=wp.stream_from_torch(query_spheres.device),
)
ctx.return_loss = return_loss
ctx.save_for_backward(out_grad)
return out_cost
@@ -939,7 +564,22 @@ class SdfMeshWarpPy(torch.autograd.Function):
grad_sph = r
if ctx.return_loss:
grad_sph = r * grad_output.unsqueeze(-1)
return grad_sph, None, None, None, None, None, None, None, None, None, None, None, None
return (
grad_sph,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class SweptSdfMeshWarpPy(torch.autograd.Function):
@@ -957,69 +597,46 @@ class SweptSdfMeshWarpPy(torch.autograd.Function):
mesh_pose_inverse,
mesh_enable,
n_env_mesh,
max_dist,
sweep_steps=1,
enable_speed_metric=False,
max_dist=0.05,
env_query_idx=None,
return_loss=False,
):
b, h, n, _ = query_spheres.shape
use_batch_env = True
if env_query_idx is None:
wp.launch(
kernel=get_swept_closest_pt,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(speed_dt),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
max_dist,
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
sweep_steps,
enable_speed_metric,
],
stream=wp.stream_from_torch(query_spheres.device),
)
else:
wp.launch(
kernel=get_swept_closest_pt_batch_env,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(speed_dt),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
max_dist,
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
sweep_steps,
enable_speed_metric,
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
],
stream=wp.stream_from_torch(query_spheres.device),
)
use_batch_env = False
env_query_idx = n_env_mesh
wp.launch(
kernel=get_swept_closest_pt_batch_env,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(speed_dt),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
wp.from_torch(max_dist, dtype=wp.float32),
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
sweep_steps,
enable_speed_metric,
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
use_batch_env,
],
stream=wp.stream_from_torch(query_spheres.device),
)
ctx.return_loss = return_loss
ctx.save_for_backward(out_grad)
return out_cost

View File

@@ -19,7 +19,7 @@ import torch
# CuRobo
from curobo.curobolib.geom import SdfSphereOBB, SdfSweptSphereOBB
from curobo.geom.types import Cuboid, Mesh, Obstacle, WorldConfig, batch_tensor_cube
from curobo.geom.types import Cuboid, Mesh, Obstacle, VoxelGrid, WorldConfig, batch_tensor_cube
from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
from curobo.util.logger import log_error, log_info, log_warn
@@ -39,10 +39,14 @@ class CollisionBuffer:
def initialize_from_shape(cls, shape: torch.Size, tensor_args: TensorDeviceType):
batch, horizon, n_spheres, _ = shape
distance_buffer = torch.zeros(
(batch, horizon, n_spheres), device=tensor_args.device, dtype=tensor_args.dtype
(batch, horizon, n_spheres),
device=tensor_args.device,
dtype=tensor_args.collision_distance_dtype,
)
grad_distance_buffer = torch.zeros(
(batch, horizon, n_spheres, 4), device=tensor_args.device, dtype=tensor_args.dtype
(batch, horizon, n_spheres, 4),
device=tensor_args.device,
dtype=tensor_args.collision_gradient_dtype,
)
sparsity_idx = torch.zeros(
(batch, horizon, n_spheres),
@@ -54,10 +58,14 @@ class CollisionBuffer:
def _update_from_shape(self, shape: torch.Size, tensor_args: TensorDeviceType):
batch, horizon, n_spheres, _ = shape
self.distance_buffer = torch.zeros(
(batch, horizon, n_spheres), device=tensor_args.device, dtype=tensor_args.dtype
(batch, horizon, n_spheres),
device=tensor_args.device,
dtype=tensor_args.collision_distance_dtype,
)
self.grad_distance_buffer = torch.zeros(
(batch, horizon, n_spheres, 4), device=tensor_args.device, dtype=tensor_args.dtype
(batch, horizon, n_spheres, 4),
device=tensor_args.device,
dtype=tensor_args.collision_gradient_dtype,
)
self.sparsity_index_buffer = torch.zeros(
(batch, horizon, n_spheres),
@@ -100,6 +108,7 @@ class CollisionQueryBuffer:
primitive_collision_buffer: Optional[CollisionBuffer] = None
mesh_collision_buffer: Optional[CollisionBuffer] = None
blox_collision_buffer: Optional[CollisionBuffer] = None
voxel_collision_buffer: Optional[CollisionBuffer] = None
shape: Optional[torch.Size] = None
def __post_init__(self):
@@ -110,6 +119,8 @@ class CollisionQueryBuffer:
self.shape = self.mesh_collision_buffer.shape
elif self.blox_collision_buffer is not None:
self.shape = self.blox_collision_buffer.shape
elif self.voxel_collision_buffer is not None:
self.shape = self.voxel_collision_buffer.shape
def __mul__(self, scalar: float):
if self.primitive_collision_buffer is not None:
@@ -118,17 +129,27 @@ class CollisionQueryBuffer:
self.mesh_collision_buffer = self.mesh_collision_buffer * scalar
if self.blox_collision_buffer is not None:
self.blox_collision_buffer = self.blox_collision_buffer * scalar
if self.voxel_collision_buffer is not None:
self.voxel_collision_buffer = self.voxel_collision_buffer * scalar
return self
def clone(self):
prim_buffer = mesh_buffer = blox_buffer = None
prim_buffer = mesh_buffer = blox_buffer = voxel_buffer = None
if self.primitive_collision_buffer is not None:
prim_buffer = self.primitive_collision_buffer.clone()
if self.mesh_collision_buffer is not None:
mesh_buffer = self.mesh_collision_buffer.clone()
if self.blox_collision_buffer is not None:
blox_buffer = self.blox_collision_buffer.clone()
return CollisionQueryBuffer(prim_buffer, mesh_buffer, blox_buffer, self.shape)
if self.voxel_collision_buffer is not None:
voxel_buffer = self.voxel_collision_buffer.clone()
return CollisionQueryBuffer(
prim_buffer,
mesh_buffer,
blox_buffer,
voxel_collision_buffer=voxel_buffer,
shape=self.shape,
)
@classmethod
def initialize_from_shape(
@@ -137,14 +158,18 @@ class CollisionQueryBuffer:
tensor_args: TensorDeviceType,
collision_types: Dict[str, bool],
):
primitive_buffer = mesh_buffer = blox_buffer = None
primitive_buffer = mesh_buffer = blox_buffer = voxel_buffer = None
if "primitive" in collision_types and collision_types["primitive"]:
primitive_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
if "mesh" in collision_types and collision_types["mesh"]:
mesh_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
if "blox" in collision_types and collision_types["blox"]:
blox_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
return CollisionQueryBuffer(primitive_buffer, mesh_buffer, blox_buffer)
if "voxel" in collision_types and collision_types["voxel"]:
voxel_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
return CollisionQueryBuffer(
primitive_buffer, mesh_buffer, blox_buffer, voxel_collision_buffer=voxel_buffer
)
def create_from_shape(
self,
@@ -160,8 +185,9 @@ class CollisionQueryBuffer:
self.mesh_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
if "blox" in collision_types and collision_types["blox"]:
self.blox_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
if "voxel" in collision_types and collision_types["voxel"]:
self.voxel_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
self.shape = shape
# return self
def update_buffer_shape(
self,
@@ -169,12 +195,10 @@ class CollisionQueryBuffer:
tensor_args: TensorDeviceType,
collision_types: Optional[Dict[str, bool]],
):
# print(shape, self.shape)
# update buffers:
assert len(shape) == 4 # shape is: batch, horizon, n_spheres, 4
if self.shape is None: # buffers not initialized:
self.create_from_shape(shape, tensor_args, collision_types)
# print("Creating new memory", self.shape)
else:
# update buffers if shape doesn't match:
# TODO: allow for dynamic change of collision_types
@@ -185,6 +209,8 @@ class CollisionQueryBuffer:
self.mesh_collision_buffer.update_buffer_shape(shape, tensor_args)
if self.blox_collision_buffer is not None:
self.blox_collision_buffer.update_buffer_shape(shape, tensor_args)
if self.voxel_collision_buffer is not None:
self.voxel_collision_buffer.update_buffer_shape(shape, tensor_args)
self.shape = shape
def get_gradient_buffer(
@@ -208,6 +234,12 @@ class CollisionQueryBuffer:
current_buffer = blox_buffer.clone()
else:
current_buffer += blox_buffer
if self.voxel_collision_buffer is not None:
voxel_buffer = self.voxel_collision_buffer.grad_distance_buffer
if current_buffer is None:
current_buffer = voxel_buffer.clone()
else:
current_buffer += voxel_buffer
return current_buffer
@@ -221,6 +253,7 @@ class CollisionCheckerType(Enum):
PRIMITIVE = "PRIMITIVE"
BLOX = "BLOX"
MESH = "MESH"
VOXEL = "VOXEL"
@dataclass
@@ -230,11 +263,13 @@ class WorldCollisionConfig:
cache: Optional[Dict[Obstacle, int]] = None
n_envs: int = 1
checker_type: CollisionCheckerType = CollisionCheckerType.PRIMITIVE
max_distance: float = 0.01
max_distance: Union[torch.Tensor, float] = 0.01
def __post_init__(self):
if self.world_model is not None and isinstance(self.world_model, list):
self.n_envs = len(self.world_model)
if isinstance(self.max_distance, float):
self.max_distance = self.tensor_args.to_device([self.max_distance])
@staticmethod
def load_from_dict(
@@ -261,6 +296,8 @@ class WorldCollision(WorldCollisionConfig):
if config is not None:
WorldCollisionConfig.__init__(self, **vars(config))
self.collision_types = {} # Use this dictionary to store collision types
self._cache_voxelization = None
self._cache_voxelization_collision_buffer = None
def load_collision_model(self, world_model: WorldConfig):
raise NotImplementedError
@@ -273,6 +310,8 @@ class WorldCollision(WorldCollisionConfig):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
"""
Computes the signed distance via analytic function
@@ -310,6 +349,7 @@ class WorldCollision(WorldCollisionConfig):
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
):
raise NotImplementedError
@@ -338,6 +378,118 @@ class WorldCollision(WorldCollisionConfig):
):
raise NotImplementedError
def get_voxels_in_bounding_box(
self,
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
voxel_size: float = 0.02,
) -> Union[List[Cuboid], torch.Tensor]:
new_grid = self.get_occupancy_in_bounding_box(cuboid, voxel_size)
occupied = new_grid.get_occupied_voxels(0.0)
return occupied
def clear_voxelization_cache(self):
self._cache_voxelization = None
def update_cache_voxelization(self, new_grid: VoxelGrid):
if (
self._cache_voxelization is None
or self._cache_voxelization.voxel_size != new_grid.voxel_size
or self._cache_voxelization.dims != new_grid.dims
):
self._cache_voxelization = new_grid
self._cache_voxelization.xyzr_tensor = self._cache_voxelization.create_xyzr_tensor(
transform_to_origin=True, tensor_args=self.tensor_args
)
self._cache_voxelization_collision_buffer = CollisionQueryBuffer()
xyzr = self._cache_voxelization.xyzr_tensor.view(-1, 1, 1, 4)
self._cache_voxelization_collision_buffer.update_buffer_shape(
xyzr.shape,
self.tensor_args,
self.collision_types,
)
def get_occupancy_in_bounding_box(
self,
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
voxel_size: float = 0.02,
) -> VoxelGrid:
new_grid = VoxelGrid(
name=cuboid.name, dims=cuboid.dims, pose=cuboid.pose, voxel_size=voxel_size
)
self.update_cache_voxelization(new_grid)
xyzr = self._cache_voxelization.xyzr_tensor
xyzr = xyzr.view(-1, 1, 1, 4)
weight = self.tensor_args.to_device([1.0])
act_distance = self.tensor_args.to_device([0.0])
d_sph = self.get_sphere_collision(
xyzr,
self._cache_voxelization_collision_buffer,
weight,
act_distance,
)
d_sph = d_sph.reshape(-1)
new_grid.xyzr_tensor = self._cache_voxelization.xyzr_tensor
new_grid.feature_tensor = d_sph
return new_grid
def get_esdf_in_bounding_box(
self,
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
voxel_size: float = 0.02,
dtype=torch.float32,
) -> VoxelGrid:
new_grid = VoxelGrid(
name=cuboid.name,
dims=cuboid.dims,
pose=cuboid.pose,
voxel_size=voxel_size,
feature_dtype=dtype,
)
self.update_cache_voxelization(new_grid)
xyzr = self._cache_voxelization.xyzr_tensor
voxel_shape = xyzr.shape
xyzr = xyzr.view(-1, 1, 1, 4)
weight = self.tensor_args.to_device([1.0])
d_sph = self.get_sphere_distance(
xyzr,
self._cache_voxelization_collision_buffer,
weight,
self.max_distance,
sum_collisions=False,
compute_esdf=True,
)
d_sph = d_sph.reshape(-1)
voxel_grid = self._cache_voxelization
voxel_grid.feature_tensor = d_sph
return voxel_grid
def get_mesh_in_bounding_box(
self,
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
voxel_size: float = 0.02,
) -> Mesh:
voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size)
# voxels = voxels.cpu().numpy()
# cuboids = [Cuboid(name="c_"+str(x), pose=[voxels[x,0],voxels[x,1],voxels[x,2], 1,0,0,0], dims=[voxel_size, voxel_size, voxel_size]) for x in range(voxels.shape[0])]
# mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0]
mesh = Mesh.from_pointcloud(
voxels[:, :3].detach().cpu().numpy(),
pitch=voxel_size * 1.1,
)
return mesh
class WorldPrimitiveCollision(WorldCollision):
"""World Oriented Bounding Box representation object
@@ -354,6 +506,7 @@ class WorldPrimitiveCollision(WorldCollision):
self._env_n_obbs = None
self._env_obbs_names = None
self._init_cache()
if self.world_model is not None:
if isinstance(self.world_model, list):
self.load_batch_collision_model(self.world_model)
@@ -656,6 +809,8 @@ class WorldPrimitiveCollision(WorldCollision):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
raise ValueError("Primitive Collision has no obstacles")
@@ -673,6 +828,7 @@ class WorldPrimitiveCollision(WorldCollision):
collision_query_buffer.primitive_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self._cube_tensor_list[0],
self._cube_tensor_list[0],
self._cube_tensor_list[1],
@@ -687,6 +843,8 @@ class WorldPrimitiveCollision(WorldCollision):
True,
use_batch_env,
return_loss,
sum_collisions,
compute_esdf,
)
return dist
@@ -699,6 +857,7 @@ class WorldPrimitiveCollision(WorldCollision):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
**kwargs,
):
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
raise ValueError("Primitive Collision has no obstacles")
@@ -717,6 +876,7 @@ class WorldPrimitiveCollision(WorldCollision):
collision_query_buffer.primitive_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self._cube_tensor_list[0],
self._cube_tensor_list[0],
self._cube_tensor_list[1],
@@ -728,7 +888,7 @@ class WorldPrimitiveCollision(WorldCollision):
h,
n,
query_sphere.requires_grad,
False,
True,
use_batch_env,
return_loss,
)
@@ -745,6 +905,7 @@ class WorldPrimitiveCollision(WorldCollision):
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
sum_collisions: bool = True,
):
"""
Computes the signed distance via analytic function
@@ -784,6 +945,7 @@ class WorldPrimitiveCollision(WorldCollision):
True,
use_batch_env,
return_loss,
sum_collisions,
)
return dist
@@ -836,7 +998,7 @@ class WorldPrimitiveCollision(WorldCollision):
sweep_steps,
enable_speed_metric,
query_sphere.requires_grad,
False,
True,
use_batch_env,
return_loss,
)
@@ -845,70 +1007,5 @@ class WorldPrimitiveCollision(WorldCollision):
def clear_cache(self):
if self._cube_tensor_list is not None:
self._cube_tensor_list[2][:] = 1
self._cube_tensor_list[2][:] = 0
self._env_n_obbs[:] = 0
def get_voxels_in_bounding_box(
self,
cuboid: Cuboid,
voxel_size: float = 0.02,
) -> Union[List[Cuboid], torch.Tensor]:
bounds = cuboid.dims
low = [-bounds[0], -bounds[1], -bounds[2]]
high = [bounds[0], bounds[1], bounds[2]]
trange = [h - l for l, h in zip(low, high)]
x = torch.linspace(
-bounds[0], bounds[0], int(trange[0] // voxel_size) + 1, device=self.tensor_args.device
)
y = torch.linspace(
-bounds[1], bounds[1], int(trange[1] // voxel_size) + 1, device=self.tensor_args.device
)
z = torch.linspace(
-bounds[2], bounds[2], int(trange[2] // voxel_size) + 1, device=self.tensor_args.device
)
w, l, h = x.shape[0], y.shape[0], z.shape[0]
xyz = (
torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
)
pose = Pose.from_list(cuboid.pose, tensor_args=self.tensor_args)
xyz = pose.transform_points(xyz.contiguous())
r = torch.zeros_like(xyz[:, 0:1]) + voxel_size
xyzr = torch.cat([xyz, r], dim=1)
xyzr = xyzr.reshape(-1, 1, 1, 4)
collision_buffer = CollisionQueryBuffer()
collision_buffer.update_buffer_shape(
xyzr.shape,
self.tensor_args,
self.collision_types,
)
weight = self.tensor_args.to_device([1.0])
act_distance = self.tensor_args.to_device([0.0])
d_sph = self.get_sphere_collision(
xyzr,
collision_buffer,
weight,
act_distance,
)
d_sph = d_sph.reshape(-1)
xyzr = xyzr.reshape(-1, 4)
# get occupied voxels:
occupied = xyzr[d_sph > 0.0]
return occupied
def get_mesh_in_bounding_box(
self,
cuboid: Cuboid,
voxel_size: float = 0.02,
) -> Mesh:
voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size)
# voxels = voxels.cpu().numpy()
# cuboids = [Cuboid(name="c_"+str(x), pose=[voxels[x,0],voxels[x,1],voxels[x,2], 1,0,0,0], dims=[voxel_size, voxel_size, voxel_size]) for x in range(voxels.shape[0])]
# mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0]
mesh = Mesh.from_pointcloud(
voxels[:, :3].detach().cpu().numpy(),
pitch=voxel_size * 1.1,
)
return mesh

View File

@@ -176,6 +176,8 @@ class WorldBloxCollision(WorldMeshCollision):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
if "blox" not in self.collision_types or not self.collision_types["blox"]:
return super().get_sphere_distance(
@@ -185,6 +187,8 @@ class WorldBloxCollision(WorldMeshCollision):
activation_distance,
env_query_idx,
return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
d = self._get_blox_sdf(
@@ -205,8 +209,13 @@ class WorldBloxCollision(WorldMeshCollision):
activation_distance,
env_query_idx,
return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
d = d + d_base
if compute_esdf:
d = torch.maximum(d, d_base)
else:
d = d + d_base
return d
@@ -262,6 +271,7 @@ class WorldBloxCollision(WorldMeshCollision):
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
):
if "blox" not in self.collision_types or not self.collision_types["blox"]:
return super().get_swept_sphere_distance(
@@ -274,6 +284,7 @@ class WorldBloxCollision(WorldMeshCollision):
enable_speed_metric,
env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
d = self._get_blox_swept_sdf(
@@ -301,6 +312,7 @@ class WorldBloxCollision(WorldMeshCollision):
enable_speed_metric,
env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
d = d + d_base

View File

@@ -89,6 +89,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
self._mesh_tensor_list[0][env_idx, :max_nmesh] = w_mid
self._mesh_tensor_list[1][env_idx, :max_nmesh, :7] = w_inv_pose
self._mesh_tensor_list[2][env_idx, :max_nmesh] = 1
self._mesh_tensor_list[2][env_idx, max_nmesh:] = 0
self._env_mesh_names[env_idx][:max_nmesh] = name_list
self._env_n_mesh[env_idx] = max_nmesh
@@ -355,6 +356,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance: torch.Tensor,
env_query_idx=None,
return_loss=False,
compute_esdf=False,
):
d = SdfMeshWarpPy.apply(
query_spheres,
@@ -370,6 +372,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
self.max_distance,
env_query_idx,
return_loss,
compute_esdf,
)
return d
@@ -397,9 +400,9 @@ class WorldMeshCollision(WorldPrimitiveCollision):
self._mesh_tensor_list[1],
self._mesh_tensor_list[2],
self._env_n_mesh,
self.max_distance,
sweep_steps,
enable_speed_metric,
self.max_distance,
env_query_idx,
return_loss,
)
@@ -413,6 +416,8 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
# TODO: if no mesh object exist, call primitive
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
@@ -423,6 +428,8 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance=activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
d = self._get_sdf(
@@ -432,6 +439,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance=activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
compute_esdf=compute_esdf,
)
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
@@ -443,8 +451,13 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance=activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
d_val = d.view(d_prim.shape) + d_prim
if compute_esdf:
d_val = torch.maximum(d.view(d_prim.shape), d_prim)
else:
d_val = d.view(d_prim.shape) + d_prim
return d_val
def get_sphere_collision(
@@ -455,6 +468,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance: torch.Tensor,
env_query_idx=None,
return_loss=False,
**kwargs,
):
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
return super().get_sphere_collision(
@@ -501,6 +515,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
):
# log_warn("Swept: Mesh + Primitive Collision Checking is experimental")
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
@@ -514,6 +529,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
speed_dt=speed_dt,
enable_speed_metric=enable_speed_metric,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
d = self._get_swept_sdf(
@@ -540,6 +556,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
speed_dt=speed_dt,
enable_speed_metric=enable_speed_metric,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
d_val = d.view(d_prim.shape) + d_prim
@@ -602,4 +619,11 @@ class WorldMeshCollision(WorldPrimitiveCollision):
self._wp_mesh_cache = {}
if self._mesh_tensor_list is not None:
self._mesh_tensor_list[2][:] = 0
if self._env_n_mesh is not None:
self._env_n_mesh[:] = 0
if self._env_mesh_names is not None:
self._env_mesh_names = [
[None for _ in range(self.cache["mesh"])] for _ in range(self.n_envs)
]
super().clear_cache()

View File

@@ -0,0 +1,699 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
import math
from typing import Any, Dict, List, Optional
# Third Party
import numpy as np
import torch
# CuRobo
from curobo.curobolib.geom import SdfSphereVoxel, SdfSweptSphereVoxel
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollisionConfig
from curobo.geom.sdf.world_mesh import WorldMeshCollision
from curobo.geom.types import VoxelGrid, WorldConfig
from curobo.types.math import Pose
from curobo.util.logger import log_error, log_info, log_warn
class WorldVoxelCollision(WorldMeshCollision):
"""Voxel grid representation of World, with each voxel containing Euclidean Signed Distance."""
def __init__(self, config: WorldCollisionConfig):
self._env_n_voxels = None
self._voxel_tensor_list = None
self._env_voxel_names = None
super().__init__(config)
def _init_cache(self):
if (
self.cache is not None
and "voxel" in self.cache
and self.cache["voxel"] not in [None, 0]
):
self._create_voxel_cache(self.cache["voxel"])
def _create_voxel_cache(self, voxel_cache: Dict[str, Any]):
n_layers = voxel_cache["layers"]
dims = voxel_cache["dims"]
voxel_size = voxel_cache["voxel_size"]
feature_dtype = voxel_cache["feature_dtype"]
n_voxels = int(
math.floor(dims[0] / voxel_size)
* math.floor(dims[1] / voxel_size)
* math.floor(dims[2] / voxel_size)
)
voxel_params = torch.zeros(
(self.n_envs, n_layers, 4),
dtype=self.tensor_args.dtype,
device=self.tensor_args.device,
)
voxel_pose = torch.zeros(
(self.n_envs, n_layers, 8),
dtype=self.tensor_args.dtype,
device=self.tensor_args.device,
)
voxel_pose[..., 3] = 1.0
voxel_enable = torch.zeros(
(self.n_envs, n_layers), dtype=torch.uint8, device=self.tensor_args.device
)
self._env_n_voxels = torch.zeros(
(self.n_envs), device=self.tensor_args.device, dtype=torch.int32
)
voxel_features = torch.zeros(
(self.n_envs, n_layers, n_voxels, 1),
device=self.tensor_args.device,
dtype=feature_dtype,
)
self._voxel_tensor_list = [voxel_params, voxel_pose, voxel_enable, voxel_features]
self.collision_types["voxel"] = True
self._env_voxel_names = [[None for _ in range(n_layers)] for _ in range(self.n_envs)]
def load_collision_model(
self, world_model: WorldConfig, env_idx=0, fix_cache_reference: bool = False
):
self._load_collision_model_in_cache(
world_model, env_idx, fix_cache_reference=fix_cache_reference
)
return super().load_collision_model(
world_model, env_idx=env_idx, fix_cache_reference=fix_cache_reference
)
def _load_collision_model_in_cache(
self, world_config: WorldConfig, env_idx: int = 0, fix_cache_reference: bool = False
):
"""TODO:
_extended_summary_
Args:
world_config: _description_
env_idx: _description_
fix_cache_reference: _description_
"""
voxel_objs = world_config.voxel
max_obs = len(voxel_objs)
self.world_model = world_config
if max_obs < 1:
log_info("No Voxel objs")
return
if self._voxel_tensor_list is None or self._voxel_tensor_list[0].shape[1] < max_obs:
if not fix_cache_reference:
log_info("Creating Voxel cache" + str(max_obs))
self._create_voxel_cache(
{
"layers": max_obs,
"dims": voxel_objs[0].dims,
"voxel_size": voxel_objs[0].voxel_size,
"feature_dtype": voxel_objs[0].feature_dtype,
}
)
else:
log_error("number of OBB is larger than collision cache, create larger cache.")
# load as a batch:
pose_batch = [c.pose for c in voxel_objs]
dims_batch = [c.dims for c in voxel_objs]
names_batch = [c.name for c in voxel_objs]
size_batch = [c.voxel_size for c in voxel_objs]
voxel_batch = self._batch_tensor_voxel(pose_batch, dims_batch, size_batch)
self._voxel_tensor_list[0][env_idx, :max_obs, :] = voxel_batch[0]
self._voxel_tensor_list[1][env_idx, :max_obs, :7] = voxel_batch[1]
self._voxel_tensor_list[2][env_idx, :max_obs] = 1 # enabling obstacle
self._voxel_tensor_list[2][env_idx, max_obs:] = 0 # disabling obstacle
# copy voxel grid features:
self._env_n_voxels[env_idx] = max_obs
self._env_voxel_names[env_idx][:max_obs] = names_batch
self.collision_types["voxel"] = True
def _batch_tensor_voxel(
self, pose: List[List[float]], dims: List[float], voxel_size: List[float]
):
w_T_b = Pose.from_batch_list(pose, tensor_args=self.tensor_args)
b_T_w = w_T_b.inverse()
dims_t = torch.as_tensor(
np.array(dims), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
size_t = torch.as_tensor(
np.array(voxel_size), device=self.tensor_args.device, dtype=self.tensor_args.dtype
).unsqueeze(-1)
params_t = torch.cat([dims_t, size_t], dim=-1)
voxel_list = [params_t, b_T_w.get_pose_vector()]
return voxel_list
def load_batch_collision_model(self, world_config_list: List[WorldConfig]):
"""Load voxel grid for batched environments
_extended_summary_
Args:
world_config_list: _description_
Returns:
_description_
"""
log_error("Not Implemented")
# First find largest number of cuboid:
c_len = []
pose_batch = []
dims_batch = []
names_batch = []
vsize_batch = []
for i in world_config_list:
c = i.cuboid
if c is not None:
c_len.append(len(c))
pose_batch.extend([i.pose for i in c])
dims_batch.extend([i.dims for i in c])
names_batch.extend([i.name for i in c])
vsize_batch.extend([i.voxel_size for i in c])
else:
c_len.append(0)
max_obs = max(c_len)
if max_obs < 1:
log_warn("No obbs found")
return
# check if number of environments is same as config:
reset_buffers = False
if self._env_n_voxels is not None and len(world_config_list) != len(self._env_n_voxels):
log_warn(
"env_n_voxels is not same as world_config_list, reloading collision buffers (breaks CG)"
)
reset_buffers = True
self.n_envs = len(world_config_list)
self._env_n_voxels = torch.zeros(
(self.n_envs), device=self.tensor_args.device, dtype=torch.int32
)
if self._voxel_tensor_list is not None and self._voxel_tensor_list[0].shape[1] < max_obs:
log_warn(
"number of obbs is greater than buffer, reloading collision buffers (breaks CG)"
)
reset_buffers = True
# create cache if does not exist:
if self._voxel_tensor_list is None or reset_buffers:
log_info("Creating Obb cache" + str(max_obs))
self._create_obb_cache(max_obs)
# load obstacles:
## load data into gpu:
voxel_batch = self._batch_tensor_voxel(pose_batch, dims_batch, vsize_batch)
c_start = 0
for i in range(len(self._env_n_voxels)):
if c_len[i] > 0:
# load obb:
self._voxel_tensor_list[0][i, : c_len[i], :] = voxel_batch[0][
c_start : c_start + c_len[i]
]
self._voxel_tensor_list[1][i, : c_len[i], :7] = voxel_batch[1][
c_start : c_start + c_len[i]
]
self._voxel_tensor_list[2][i, : c_len[i]] = 1
self._env_voxel_names[i][: c_len[i]] = names_batch[c_start : c_start + c_len[i]]
self._voxel_tensor_list[2][i, c_len[i] :] = 0
c_start += c_len[i]
self._env_n_voxels[:] = torch.as_tensor(
c_len, dtype=torch.int32, device=self.tensor_args.device
)
self.collision_types["voxel"] = True
return super().load_batch_collision_model(world_config_list)
def enable_obstacle(
self,
name: str,
enable: bool = True,
env_idx: int = 0,
):
if self._env_voxel_names is not None and name in self._env_voxel_names[env_idx]:
self.enable_voxel(enable, name, None, env_idx)
else:
return super().enable_obstacle(name, enable, env_idx)
def enable_voxel(
self,
enable: bool = True,
name: Optional[str] = None,
env_obj_idx: Optional[torch.Tensor] = None,
env_idx: int = 0,
):
"""Update obstacle dimensions
Args:
obj_dims (torch.Tensor): [dim.x,dim.y, dim.z], give as [b,3]
obj_idx (torch.Tensor or int):
"""
if env_obj_idx is not None:
self._voxel_tensor_list[2][env_obj_idx] = int(enable) # enable == 1
else:
# find index of given name:
obs_idx = self.get_voxel_idx(name, env_idx)
self._voxel_tensor_list[2][env_idx, obs_idx] = int(enable)
def update_obstacle_pose(
self,
name: str,
w_obj_pose: Pose,
env_idx: int = 0,
):
if self._env_voxel_names is not None and name in self._env_voxel_names[env_idx]:
self.update_voxel_pose(name=name, w_obj_pose=w_obj_pose, env_idx=env_idx)
else:
log_error("obstacle not found in OBB world model: " + name)
def update_voxel_data(self, new_voxel: VoxelGrid, env_idx: int = 0):
obs_idx = self.get_voxel_idx(new_voxel.name, env_idx)
self._voxel_tensor_list[3][env_idx, obs_idx, :, :] = new_voxel.feature_tensor.view(
new_voxel.feature_tensor.shape[0], -1
).to(dtype=self._voxel_tensor_list[3].dtype)
self._voxel_tensor_list[0][env_idx, obs_idx, :3] = self.tensor_args.to_device(
new_voxel.dims
)
self._voxel_tensor_list[0][env_idx, obs_idx, 3] = new_voxel.voxel_size
self._voxel_tensor_list[1][env_idx, obs_idx, :7] = (
Pose.from_list(new_voxel.pose, self.tensor_args).inverse().get_pose_vector()
)
self._voxel_tensor_list[2][env_idx, obs_idx] = int(True)
def update_voxel_features(
self,
features: torch.Tensor,
name: Optional[str] = None,
env_obj_idx: Optional[torch.Tensor] = None,
env_idx: int = 0,
):
"""Update pose of a specific objects.
This also updates the signed distance grid to account for the updated object pose.
Args:
obj_w_pose: Pose
obj_idx:
"""
if env_obj_idx is not None:
self._voxel_tensor_list[3][env_obj_idx, :] = features.to(
dtype=self._voxel_tensor_list[3].dtype
)
else:
obs_idx = self.get_voxel_idx(name, env_idx)
self._voxel_tensor_list[3][env_idx, obs_idx, :] = features.to(
dtype=self._voxel_tensor_list[3].dtype
)
def update_voxel_pose(
self,
w_obj_pose: Optional[Pose] = None,
obj_w_pose: Optional[Pose] = None,
name: Optional[str] = None,
env_obj_idx: Optional[torch.Tensor] = None,
env_idx: int = 0,
):
"""Update pose of a specific objects.
This also updates the signed distance grid to account for the updated object pose.
Args:
obj_w_pose: Pose
obj_idx:
"""
obj_w_pose = self._get_obstacle_poses(w_obj_pose, obj_w_pose)
if env_obj_idx is not None:
self._voxel_tensor_list[1][env_obj_idx, :7] = obj_w_pose.get_pose_vector()
else:
obs_idx = self.get_voxel_idx(name, env_idx)
self._voxel_tensor_list[1][env_idx, obs_idx, :7] = obj_w_pose.get_pose_vector()
def get_voxel_idx(
self,
name: str,
env_idx: int = 0,
) -> int:
if name not in self._env_voxel_names[env_idx]:
log_error("Obstacle with name: " + name + " not found in current world", exc_info=True)
return self._env_voxel_names[env_idx].index(name)
def get_voxel_grid(
self,
name: str,
env_idx: int = 0,
):
obs_idx = self.get_voxel_idx(name, env_idx)
voxel_params = np.round(
self._voxel_tensor_list[0][env_idx, obs_idx, :].cpu().numpy().astype(np.float64), 6
).tolist()
voxel_pose = Pose(
position=self._voxel_tensor_list[1][env_idx, obs_idx, :3],
quaternion=self._voxel_tensor_list[1][env_idx, obs_idx, 3:7],
)
voxel_features = self._voxel_tensor_list[3][env_idx, obs_idx, :]
voxel_grid = VoxelGrid(
name=name,
dims=voxel_params[:3],
pose=voxel_pose.to_list(),
voxel_size=voxel_params[3],
feature_tensor=voxel_features,
)
return voxel_grid
def get_sphere_distance(
self,
query_sphere,
collision_query_buffer: CollisionQueryBuffer,
weight: torch.Tensor,
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
return super().get_sphere_distance(
query_sphere,
collision_query_buffer,
weight,
activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
b, h, n, _ = query_sphere.shape # This can be read from collision query buffer
use_batch_env = True
if env_query_idx is None:
use_batch_env = False
env_query_idx = self._env_n_voxels
dist = SdfSphereVoxel.apply(
query_sphere,
collision_query_buffer.voxel_collision_buffer.distance_buffer,
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self._voxel_tensor_list[3],
self._voxel_tensor_list[0],
self._voxel_tensor_list[1],
self._voxel_tensor_list[2],
self._env_n_voxels,
env_query_idx,
self._voxel_tensor_list[0].shape[1],
b,
h,
n,
query_sphere.requires_grad,
True,
use_batch_env,
return_loss,
sum_collisions,
compute_esdf,
)
if (
"primitive" not in self.collision_types
or not self.collision_types["primitive"]
or "mesh" not in self.collision_types
or not self.collision_types["mesh"]
):
return dist
d_prim = super().get_sphere_distance(
query_sphere,
collision_query_buffer,
weight=weight,
activation_distance=activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
if compute_esdf:
d_val = torch.maximum(dist.view(d_prim.shape), d_prim)
else:
d_val = d_val.view(d_prim.shape) + d_prim
return d_val
def get_sphere_collision(
self,
query_sphere,
collision_query_buffer: CollisionQueryBuffer,
weight: torch.Tensor,
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
**kwargs,
):
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
return super().get_sphere_collision(
query_sphere,
collision_query_buffer,
weight,
activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
)
if return_loss:
raise ValueError("cannot return loss for classification, use get_sphere_distance")
b, h, n, _ = query_sphere.shape
use_batch_env = True
if env_query_idx is None:
use_batch_env = False
env_query_idx = self._env_n_voxels
dist = SdfSphereVoxel.apply(
query_sphere,
collision_query_buffer.voxel_collision_buffer.distance_buffer,
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self._voxel_tensor_list[3],
self._voxel_tensor_list[0],
self._voxel_tensor_list[1],
self._voxel_tensor_list[2],
self._env_n_voxels,
env_query_idx,
self._voxel_tensor_list[0].shape[1],
b,
h,
n,
query_sphere.requires_grad,
False,
use_batch_env,
False,
True,
False,
)
if (
"primitive" not in self.collision_types
or not self.collision_types["primitive"]
or "mesh" not in self.collision_types
or not self.collision_types["mesh"]
):
return dist
d_prim = super().get_sphere_collision(
query_sphere,
collision_query_buffer,
weight,
activation_distance=activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
)
d_val = dist.view(d_prim.shape) + d_prim
return d_val
def get_swept_sphere_distance(
self,
query_sphere,
collision_query_buffer: CollisionQueryBuffer,
weight: torch.Tensor,
activation_distance: torch.Tensor,
speed_dt: torch.Tensor,
sweep_steps: int,
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
sum_collisions: bool = True,
):
"""
Computes the signed distance via analytic function
Args:
tensor_sphere: b, n, 4
"""
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
return super().get_swept_sphere_distance(
query_sphere,
collision_query_buffer,
weight=weight,
env_query_idx=env_query_idx,
sweep_steps=sweep_steps,
activation_distance=activation_distance,
speed_dt=speed_dt,
enable_speed_metric=enable_speed_metric,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
b, h, n, _ = query_sphere.shape
use_batch_env = True
if env_query_idx is None:
use_batch_env = False
env_query_idx = self._env_n_voxels
dist = SdfSweptSphereVoxel.apply(
query_sphere,
collision_query_buffer.voxel_collision_buffer.distance_buffer,
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
speed_dt,
self._voxel_tensor_list[3],
self._voxel_tensor_list[0],
self._voxel_tensor_list[1],
self._voxel_tensor_list[2],
self._env_n_voxels,
env_query_idx,
self._voxel_tensor_list[0].shape[1],
b,
h,
n,
sweep_steps,
enable_speed_metric,
query_sphere.requires_grad,
True,
use_batch_env,
return_loss,
sum_collisions,
)
if (
"primitive" not in self.collision_types
or not self.collision_types["primitive"]
or "mesh" not in self.collision_types
or not self.collision_types["mesh"]
):
return dist
d_prim = super().get_swept_sphere_distance(
query_sphere,
collision_query_buffer,
weight=weight,
env_query_idx=env_query_idx,
sweep_steps=sweep_steps,
activation_distance=activation_distance,
speed_dt=speed_dt,
enable_speed_metric=enable_speed_metric,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
d_val = dist.view(d_prim.shape) + d_prim
return d_val
def get_swept_sphere_collision(
self,
query_sphere,
collision_query_buffer: CollisionQueryBuffer,
weight: torch.Tensor,
activation_distance: torch.Tensor,
speed_dt: torch.Tensor,
sweep_steps: int,
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
):
"""
Computes the signed distance via analytic function
Args:
tensor_sphere: b, n, 4
"""
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
return super().get_swept_sphere_collision(
query_sphere,
collision_query_buffer,
weight=weight,
env_query_idx=env_query_idx,
sweep_steps=sweep_steps,
activation_distance=activation_distance,
speed_dt=speed_dt,
enable_speed_metric=enable_speed_metric,
return_loss=return_loss,
)
if return_loss:
raise ValueError("cannot return loss for classify, use get_swept_sphere_distance")
b, h, n, _ = query_sphere.shape
use_batch_env = True
if env_query_idx is None:
use_batch_env = False
env_query_idx = self._env_n_voxels
dist = SdfSweptSphereVoxel.apply(
query_sphere,
collision_query_buffer.voxel_collision_buffer.distance_buffer,
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
speed_dt,
self._voxel_tensor_list[3],
self._voxel_tensor_list[0],
self._voxel_tensor_list[1],
self._voxel_tensor_list[2],
self._env_n_voxels,
env_query_idx,
self._voxel_tensor_list[0].shape[1],
b,
h,
n,
sweep_steps,
enable_speed_metric,
query_sphere.requires_grad,
False,
use_batch_env,
return_loss,
True,
)
if (
"primitive" not in self.collision_types
or not self.collision_types["primitive"]
or "mesh" not in self.collision_types
or not self.collision_types["mesh"]
):
return dist
d_prim = super().get_swept_sphere_collision(
query_sphere,
collision_query_buffer,
weight=weight,
env_query_idx=env_query_idx,
sweep_steps=sweep_steps,
activation_distance=activation_distance,
speed_dt=speed_dt,
enable_speed_metric=enable_speed_metric,
return_loss=return_loss,
)
d_val = dist.view(d_prim.shape) + d_prim
return d_val
def clear_cache(self):
if self._voxel_tensor_list is not None:
self._voxel_tensor_list[2][:] = 0
self._voxel_tensor_list[-1][:] = -1.0 * self.max_distance
self._env_n_voxels[:] = 0

View File

@@ -18,6 +18,7 @@ import warp as wp
# CuRobo
from curobo.curobolib.kinematics import rotation_matrix_to_quaternion
from curobo.util.logger import log_error
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.warp import init_warp
@@ -27,11 +28,11 @@ def transform_points(
if out_points is None:
out_points = torch.zeros((points.shape[0], 3), device=points.device, dtype=points.dtype)
if out_gp is None:
out_gp = torch.zeros((position.shape[0], 3), device=position.device)
out_gp = torch.zeros((position.shape[0], 3), device=position.device, dtype=points.dtype)
if out_gq is None:
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device)
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device, dtype=points.dtype)
if out_gpt is None:
out_gpt = torch.zeros((points.shape[0], 3), device=position.device)
out_gpt = torch.zeros((points.shape[0], 3), device=position.device, dtype=points.dtype)
out_points = TransformPoint.apply(
position, quaternion, points, out_points, out_gp, out_gq, out_gpt
)
@@ -46,18 +47,20 @@ def batch_transform_points(
(points.shape[0], points.shape[1], 3), device=points.device, dtype=points.dtype
)
if out_gp is None:
out_gp = torch.zeros((position.shape[0], 3), device=position.device)
out_gp = torch.zeros((position.shape[0], 3), device=position.device, dtype=points.dtype)
if out_gq is None:
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device)
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device, dtype=points.dtype)
if out_gpt is None:
out_gpt = torch.zeros((points.shape[0], points.shape[1], 3), device=position.device)
out_gpt = torch.zeros(
(points.shape[0], points.shape[1], 3), device=position.device, dtype=points.dtype
)
out_points = BatchTransformPoint.apply(
position, quaternion, points, out_points, out_gp, out_gq, out_gpt
)
return out_points
@torch.jit.script
@get_torch_jit_decorator()
def get_inv_transform(w_rot_c, w_trans_c):
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
c_rot_w = w_rot_c.transpose(-1, -2)
@@ -65,7 +68,7 @@ def get_inv_transform(w_rot_c, w_trans_c):
return c_rot_w, c_trans_w
@torch.jit.script
@get_torch_jit_decorator()
def transform_point_inverse(point, rot, trans):
# type: (Tensor, Tensor, Tensor) -> Tensor

View File

@@ -11,6 +11,7 @@
from __future__ import annotations
# Standard Library
import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Union
@@ -564,6 +565,68 @@ class PointCloud(Obstacle):
return new_spheres
@dataclass
class VoxelGrid(Obstacle):
dims: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0])
voxel_size: float = 0.02 # meters
feature_tensor: Optional[torch.Tensor] = None
xyzr_tensor: Optional[torch.Tensor] = None
feature_dtype: torch.dtype = torch.float32
def __post_init__(self):
if self.feature_tensor is not None:
self.feature_dtype = self.feature_tensor.dtype
def create_xyzr_tensor(
self, transform_to_origin: bool = False, tensor_args: TensorDeviceType = TensorDeviceType()
):
bounds = self.dims
low = [-bounds[0] / 2, -bounds[1] / 2, -bounds[2] / 2]
high = [bounds[0] / 2, bounds[1] / 2, bounds[2] / 2]
trange = [h - l for l, h in zip(low, high)]
x = torch.linspace(
low[0], high[0], int(math.floor(trange[0] / self.voxel_size)), device=tensor_args.device
)
y = torch.linspace(
low[1], high[1], int(math.floor(trange[1] / self.voxel_size)), device=tensor_args.device
)
z = torch.linspace(
low[2], high[2], int(math.floor(trange[2] / self.voxel_size)), device=tensor_args.device
)
w, l, h = x.shape[0], y.shape[0], z.shape[0]
xyz = (
torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
)
if transform_to_origin:
pose = Pose.from_list(self.pose, tensor_args=tensor_args)
xyz = pose.transform_points(xyz.contiguous())
r = torch.zeros_like(xyz[:, 0:1]) + (self.voxel_size * 0.5)
xyzr = torch.cat([xyz, r], dim=1)
return xyzr
def get_occupied_voxels(self, feature_threshold: Optional[float] = None):
if feature_threshold is None:
feature_threshold = -1.0 * self.voxel_size
if self.xyzr_tensor is None or self.feature_tensor is None:
log_error("Feature tensor or xyzr tensor is empty")
xyzr = self.xyzr_tensor.clone()
xyzr[:, 3] = self.feature_tensor
occupied = xyzr[self.feature_tensor > feature_threshold]
return occupied
def clone(self):
return VoxelGrid(
name=self.name,
pose=self.pose.copy(),
dims=self.dims.copy(),
feature_tensor=self.feature_tensor.clone() if self.feature_tensor is not None else None,
xyzr_tensor=self.xyzr_tensor.clone() if self.xyzr_tensor is not None else None,
feature_dtype=self.feature_dtype,
voxel_size=self.voxel_size,
)
@dataclass
class WorldConfig(Sequence):
"""Representation of World for use in CuRobo."""
@@ -586,25 +649,13 @@ class WorldConfig(Sequence):
#: BloxMap obstacle.
blox: Optional[List[BloxMap]] = None
voxel: Optional[List[VoxelGrid]] = None
#: List of all obstacles in world.
objects: Optional[List[Obstacle]] = None
def __post_init__(self):
# create objects list:
if self.objects is None:
self.objects = []
if self.sphere is not None:
self.objects += self.sphere
if self.cuboid is not None:
self.objects += self.cuboid
if self.capsule is not None:
self.objects += self.capsule
if self.mesh is not None:
self.objects += self.mesh
if self.blox is not None:
self.objects += self.blox
if self.cylinder is not None:
self.objects += self.cylinder
if self.sphere is None:
self.sphere = []
if self.cuboid is None:
@@ -617,6 +668,18 @@ class WorldConfig(Sequence):
self.cylinder = []
if self.blox is None:
self.blox = []
if self.voxel is None:
self.voxel = []
if self.objects is None:
self.objects = (
self.sphere
+ self.cuboid
+ self.capsule
+ self.mesh
+ self.cylinder
+ self.blox
+ self.voxel
)
def __len__(self):
return len(self.objects)
@@ -632,6 +695,7 @@ class WorldConfig(Sequence):
capsule=self.capsule.copy() if self.capsule is not None else None,
cylinder=self.cylinder.copy() if self.cylinder is not None else None,
blox=self.blox.copy() if self.blox is not None else None,
voxel=self.voxel.copy() if self.voxel is not None else None,
)
@staticmethod
@@ -642,6 +706,7 @@ class WorldConfig(Sequence):
mesh = None
blox = None
cylinder = None
voxel = None
# load yaml:
if "cuboid" in data_dict.keys():
cuboid = [Cuboid(name=x, **data_dict["cuboid"][x]) for x in data_dict["cuboid"]]
@@ -655,6 +720,8 @@ class WorldConfig(Sequence):
cylinder = [Cylinder(name=x, **data_dict["cylinder"][x]) for x in data_dict["cylinder"]]
if "blox" in data_dict.keys():
blox = [BloxMap(name=x, **data_dict["blox"][x]) for x in data_dict["blox"]]
if "voxel" in data_dict.keys():
voxel = [VoxelGrid(name=x, **data_dict["voxel"][x]) for x in data_dict["voxel"]]
return WorldConfig(
cuboid=cuboid,
@@ -663,6 +730,7 @@ class WorldConfig(Sequence):
cylinder=cylinder,
mesh=mesh,
blox=blox,
voxel=voxel,
)
# load world config as obbs: convert all types to obbs
@@ -688,6 +756,10 @@ class WorldConfig(Sequence):
if current_world.mesh is not None and len(current_world.mesh) > 0:
mesh_obb = [x.get_cuboid() for x in current_world.mesh]
if current_world.voxel is not None and len(current_world.voxel) > 0:
log_error("VoxelGrid cannot be converted to obb world")
return WorldConfig(
cuboid=cuboid_obb + sphere_obb + capsule_obb + cylinder_obb + mesh_obb + blox_obb
)
@@ -714,6 +786,8 @@ class WorldConfig(Sequence):
for i in range(len(current_world.blox)):
if current_world.blox[i].mesh is not None:
blox_obb.append(current_world.blox[i].get_mesh(process=process))
if current_world.voxel is not None and len(current_world.voxel) > 0:
log_error("VoxelGrid cannot be converted to mesh world")
return WorldConfig(
mesh=current_world.mesh
@@ -750,6 +824,7 @@ class WorldConfig(Sequence):
return WorldConfig(
mesh=current_world.mesh + sphere_obb + capsule_obb + cylinder_obb + blox_obb,
cuboid=cuboid_obb,
voxel=current_world.voxel,
)
@staticmethod
@@ -822,6 +897,8 @@ class WorldConfig(Sequence):
self.cylinder.append(obstacle)
elif isinstance(obstacle, Capsule):
self.capsule.append(obstacle)
elif isinstance(obstacle, VoxelGrid):
self.voxel.append(obstacle)
else:
ValueError("Obstacle type not supported")
self.objects.append(obstacle)