Significantly improved convergence for mesh and cuboid, new ESDF collision.
This commit is contained in:
@@ -12,8 +12,11 @@
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
@torch.jit.script
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def project_depth_to_pointcloud(depth_image: torch.Tensor, intrinsics_matrix: torch.Tensor):
|
||||
"""Projects numpy depth image to point cloud.
|
||||
|
||||
@@ -43,7 +46,7 @@ def project_depth_to_pointcloud(depth_image: torch.Tensor, intrinsics_matrix: to
|
||||
return raw_pc
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor):
|
||||
"""Projects numpy depth image to point cloud.
|
||||
|
||||
@@ -54,10 +57,10 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
|
||||
Returns:
|
||||
array of float (h, w, 3)
|
||||
"""
|
||||
fx = intrinsics_matrix[:, 0, 0]
|
||||
fy = intrinsics_matrix[:, 1, 1]
|
||||
cx = intrinsics_matrix[:, 0, 2]
|
||||
cy = intrinsics_matrix[:, 1, 2]
|
||||
fx = intrinsics_matrix[:, 0:1, 0:1]
|
||||
fy = intrinsics_matrix[:, 1:2, 1:2]
|
||||
cx = intrinsics_matrix[:, 0:1, 2:3]
|
||||
cy = intrinsics_matrix[:, 1:2, 2:3]
|
||||
|
||||
input_x = torch.arange(width, dtype=torch.float32, device=intrinsics_matrix.device)
|
||||
input_y = torch.arange(height, dtype=torch.float32, device=intrinsics_matrix.device)
|
||||
@@ -73,7 +76,6 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
|
||||
device=intrinsics_matrix.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
output_x = (input_x - cx) / fx
|
||||
output_y = (input_y - cy) / fy
|
||||
|
||||
@@ -84,7 +86,7 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
|
||||
return rays
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def project_pointcloud_to_depth(
|
||||
pointcloud: torch.Tensor,
|
||||
output_image: torch.Tensor,
|
||||
@@ -106,7 +108,7 @@ def project_pointcloud_to_depth(
|
||||
return output_image
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def project_depth_using_rays(
|
||||
depth_image: torch.Tensor, rays: torch.Tensor, filter_origin: bool = False
|
||||
):
|
||||
|
||||
@@ -11,8 +11,10 @@
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
# from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# @torch.jit.script
|
||||
|
||||
# @get_torch_jit_decorator()
|
||||
def lookup_distance(pt, dist_matrix_flat, num_voxels):
|
||||
# flatten:
|
||||
ind_pt = (
|
||||
@@ -22,7 +24,7 @@ def lookup_distance(pt, dist_matrix_flat, num_voxels):
|
||||
return dist
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def compute_sdf_gradient(pt, dist_matrix_flat, num_voxels, dist):
|
||||
grad_l = []
|
||||
for i in range(3): # x,y,z
|
||||
|
||||
@@ -30,6 +30,11 @@ def create_collision_checker(config: WorldCollisionConfig):
|
||||
from curobo.geom.sdf.world_mesh import WorldMeshCollision
|
||||
|
||||
return WorldMeshCollision(config)
|
||||
elif config.checker_type == CollisionCheckerType.VOXEL:
|
||||
# CuRobo
|
||||
from curobo.geom.sdf.world_voxel import WorldVoxelCollision
|
||||
|
||||
return WorldVoxelCollision(config)
|
||||
else:
|
||||
log_error("Not implemented", exc_info=True)
|
||||
log_error("Unknown Collision Checker type: " + config.checker_type, exc_info=True)
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -16,284 +16,6 @@ import warp as wp
|
||||
wp.set_module_options({"fast_math": False})
|
||||
|
||||
|
||||
# create warp kernels:
|
||||
@wp.kernel
|
||||
def get_swept_closest_pt(
|
||||
pt: wp.array(dtype=wp.vec4),
|
||||
distance: wp.array(dtype=wp.float32), # this stores the output cost
|
||||
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
|
||||
sparsity_idx: wp.array(dtype=wp.uint8),
|
||||
weight: wp.array(dtype=wp.float32),
|
||||
activation_distance: wp.array(dtype=wp.float32), # eta threshold
|
||||
speed_dt: wp.array(dtype=wp.float32),
|
||||
mesh: wp.array(dtype=wp.uint64),
|
||||
mesh_pose: wp.array(dtype=wp.float32),
|
||||
mesh_enable: wp.array(dtype=wp.uint8),
|
||||
n_env_mesh: wp.array(dtype=wp.int32),
|
||||
max_dist: wp.float32,
|
||||
write_grad: wp.uint8,
|
||||
batch_size: wp.int32,
|
||||
horizon: wp.int32,
|
||||
nspheres: wp.int32,
|
||||
max_nmesh: wp.int32,
|
||||
sweep_steps: wp.uint8,
|
||||
enable_speed_metric: wp.uint8,
|
||||
):
|
||||
# we launch nspheres kernels
|
||||
# compute gradient here and return
|
||||
# distance is negative outside and positive inside
|
||||
tid = int(0)
|
||||
tid = wp.tid()
|
||||
|
||||
b_idx = int(0)
|
||||
h_idx = int(0)
|
||||
sph_idx = int(0)
|
||||
# read horizon
|
||||
eta = float(0.0) # 5cm buffer
|
||||
dt = float(1.0)
|
||||
|
||||
b_idx = tid / (horizon * nspheres)
|
||||
|
||||
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
|
||||
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
|
||||
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
|
||||
return
|
||||
|
||||
n_mesh = int(0)
|
||||
# $wp.printf("%d, %d, %d, %d \n", tid, b_idx, h_idx, sph_idx)
|
||||
# read sphere
|
||||
sphere_0_distance = float(0.0)
|
||||
sphere_2_distance = float(0.0)
|
||||
|
||||
sphere_0 = wp.vec3(0.0)
|
||||
sphere_2 = wp.vec3(0.0)
|
||||
sphere_int = wp.vec3(0.0)
|
||||
sphere_temp = wp.vec3(0.0)
|
||||
k0 = float(0.0)
|
||||
face_index = int(0)
|
||||
face_u = float(0.0)
|
||||
face_v = float(0.0)
|
||||
sign = float(0.0)
|
||||
dist = float(0.0)
|
||||
|
||||
uint_zero = wp.uint8(0)
|
||||
uint_one = wp.uint8(1)
|
||||
cl_pt = wp.vec3(0.0)
|
||||
local_pt = wp.vec3(0.0)
|
||||
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
|
||||
in_rad = in_sphere[3]
|
||||
if in_rad < 0.0:
|
||||
distance[tid] = 0.0
|
||||
if write_grad == 1 and sparsity_idx[tid] == uint_one:
|
||||
sparsity_idx[tid] = uint_zero
|
||||
closest_pt[tid * 4] = 0.0
|
||||
closest_pt[tid * 4 + 1] = 0.0
|
||||
closest_pt[tid * 4 + 2] = 0.0
|
||||
|
||||
return
|
||||
eta = activation_distance[0]
|
||||
dt = speed_dt[0]
|
||||
in_rad += eta
|
||||
|
||||
max_dist_buffer = float(0.0)
|
||||
max_dist_buffer = max_dist
|
||||
if in_rad > max_dist_buffer:
|
||||
max_dist_buffer += in_rad
|
||||
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
|
||||
|
||||
# read in sphere 0:
|
||||
# read in sphere 0:
|
||||
if h_idx > 0:
|
||||
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx - 1) * nspheres) + sph_idx]
|
||||
sphere_0 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
|
||||
sphere_0_distance = wp.length(sphere_0 - in_pt) / 2.0
|
||||
|
||||
if h_idx < horizon - 1:
|
||||
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx + 1) * nspheres) + sph_idx]
|
||||
sphere_2 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
|
||||
sphere_2_distance = wp.length(sphere_2 - in_pt) / 2.0
|
||||
|
||||
# read in sphere 2:
|
||||
closest_distance = float(0.0)
|
||||
closest_point = wp.vec3(0.0)
|
||||
i = int(0)
|
||||
dis_length = float(0.0)
|
||||
jump_distance = float(0.0)
|
||||
mid_distance = float(0.0)
|
||||
n_mesh = n_env_mesh[0]
|
||||
obj_position = wp.vec3()
|
||||
|
||||
while i < n_mesh:
|
||||
if mesh_enable[i] == uint_one:
|
||||
obj_position[0] = mesh_pose[i * 8 + 0]
|
||||
obj_position[1] = mesh_pose[i * 8 + 1]
|
||||
obj_position[2] = mesh_pose[i * 8 + 2]
|
||||
obj_quat = wp.quaternion(
|
||||
mesh_pose[i * 8 + 4],
|
||||
mesh_pose[i * 8 + 5],
|
||||
mesh_pose[i * 8 + 6],
|
||||
mesh_pose[i * 8 + 3],
|
||||
)
|
||||
|
||||
obj_w_pose = wp.transform(obj_position, obj_quat)
|
||||
obj_w_pose_t = wp.transform_inverse(obj_w_pose)
|
||||
# transform point to mesh frame:
|
||||
# mesh_pt = T_inverse @ w_pt
|
||||
local_pt = wp.transform_point(obj_w_pose, in_pt)
|
||||
|
||||
if wp.mesh_query_point(
|
||||
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - local_pt)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - local_pt) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
else:
|
||||
dist = -1.0 * dist
|
||||
else:
|
||||
dist = in_rad
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
|
||||
mid_distance = dist
|
||||
# transform sphere -1
|
||||
if h_idx > 0 and mid_distance < sphere_0_distance:
|
||||
jump_distance = mid_distance
|
||||
j = int(0)
|
||||
sphere_temp = wp.transform_point(obj_w_pose, sphere_0)
|
||||
while j < sweep_steps:
|
||||
k0 = (
|
||||
1.0 - 0.5 * jump_distance / sphere_0_distance
|
||||
) # dist could be greater than sphere_0_distance here?
|
||||
sphere_int = k0 * local_pt + ((1.0 - k0) * sphere_temp)
|
||||
if wp.mesh_query_point(
|
||||
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - sphere_int)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - sphere_int) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
jump_distance += dist
|
||||
else:
|
||||
dist = max(-dist - in_rad, in_rad)
|
||||
|
||||
jump_distance += dist
|
||||
else:
|
||||
jump_distance += max_dist_buffer
|
||||
j += 1
|
||||
if jump_distance >= sphere_0_distance:
|
||||
j = int(sweep_steps)
|
||||
# transform sphere -1
|
||||
if h_idx < horizon - 1 and mid_distance < sphere_2_distance:
|
||||
jump_distance = mid_distance
|
||||
j = int(0)
|
||||
sphere_temp = wp.transform_point(obj_w_pose, sphere_2)
|
||||
while j < sweep_steps:
|
||||
k0 = (
|
||||
1.0 - 0.5 * jump_distance / sphere_2_distance
|
||||
) # dist could be greater than sphere_0_distance here?
|
||||
sphere_int = k0 * local_pt + (1.0 - k0) * sphere_temp
|
||||
|
||||
if wp.mesh_query_point(
|
||||
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - sphere_int)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - sphere_int) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
|
||||
jump_distance += dist
|
||||
else:
|
||||
dist = max(-dist - in_rad, in_rad)
|
||||
|
||||
jump_distance += dist
|
||||
else:
|
||||
jump_distance += max_dist_buffer
|
||||
|
||||
j += 1
|
||||
if jump_distance >= sphere_2_distance:
|
||||
j = int(sweep_steps)
|
||||
|
||||
i += 1
|
||||
# return
|
||||
if closest_distance == 0:
|
||||
if sparsity_idx[tid] == uint_zero:
|
||||
return
|
||||
sparsity_idx[tid] = uint_zero
|
||||
distance[tid] = 0.0
|
||||
if write_grad == 1:
|
||||
closest_pt[tid * 4 + 0] = 0.0
|
||||
closest_pt[tid * 4 + 1] = 0.0
|
||||
closest_pt[tid * 4 + 2] = 0.0
|
||||
return
|
||||
|
||||
if enable_speed_metric == 1 and (h_idx > 0 and h_idx < horizon - 1):
|
||||
# calculate sphere velocity and acceleration:
|
||||
norm_vel_vec = wp.vec3(0.0)
|
||||
sph_acc_vec = wp.vec3(0.0)
|
||||
sph_vel = wp.float(0.0)
|
||||
|
||||
# use central difference
|
||||
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
|
||||
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
|
||||
# norm_vel_vec = -1.0 * norm_vel_vec
|
||||
# sph_acc_vec = -1.0 * sph_acc_vec
|
||||
sph_vel = wp.length(norm_vel_vec)
|
||||
|
||||
norm_vel_vec = norm_vel_vec / sph_vel
|
||||
|
||||
orth_proj = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) - wp.outer(
|
||||
norm_vel_vec, norm_vel_vec
|
||||
)
|
||||
|
||||
curvature_vec = orth_proj * (sph_acc_vec / (sph_vel * sph_vel))
|
||||
|
||||
closest_point = sph_vel * ((orth_proj * closest_point) - closest_distance * curvature_vec)
|
||||
|
||||
closest_distance = sph_vel * closest_distance
|
||||
|
||||
distance[tid] = weight[0] * closest_distance
|
||||
sparsity_idx[tid] = uint_one
|
||||
if write_grad == 1:
|
||||
# compute gradient:
|
||||
if closest_distance > 0.0:
|
||||
closest_distance = weight[0]
|
||||
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
|
||||
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
|
||||
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
|
||||
|
||||
|
||||
@wp.kernel
|
||||
def get_swept_closest_pt_batch_env(
|
||||
pt: wp.array(dtype=wp.vec4),
|
||||
@@ -307,7 +29,7 @@ def get_swept_closest_pt_batch_env(
|
||||
mesh_pose: wp.array(dtype=wp.float32),
|
||||
mesh_enable: wp.array(dtype=wp.uint8),
|
||||
n_env_mesh: wp.array(dtype=wp.int32),
|
||||
max_dist: wp.float32,
|
||||
max_dist: wp.array(dtype=wp.float32),
|
||||
write_grad: wp.uint8,
|
||||
batch_size: wp.int32,
|
||||
horizon: wp.int32,
|
||||
@@ -316,6 +38,7 @@ def get_swept_closest_pt_batch_env(
|
||||
sweep_steps: wp.uint8,
|
||||
enable_speed_metric: wp.uint8,
|
||||
env_query_idx: wp.array(dtype=wp.int32),
|
||||
use_batch_env: wp.uint8,
|
||||
):
|
||||
# we launch nspheres kernels
|
||||
# compute gradient here and return
|
||||
@@ -357,6 +80,7 @@ def get_swept_closest_pt_batch_env(
|
||||
sign = float(0.0)
|
||||
dist = float(0.0)
|
||||
dist_metric = float(0.0)
|
||||
euclidean_distance = float(0.0)
|
||||
cl_pt = wp.vec3(0.0)
|
||||
local_pt = wp.vec3(0.0)
|
||||
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
|
||||
@@ -374,7 +98,7 @@ def get_swept_closest_pt_batch_env(
|
||||
eta = activation_distance[0]
|
||||
in_rad += eta
|
||||
max_dist_buffer = float(0.0)
|
||||
max_dist_buffer = max_dist
|
||||
max_dist_buffer = max_dist[0]
|
||||
if (in_rad) > max_dist_buffer:
|
||||
max_dist_buffer += in_rad
|
||||
|
||||
@@ -396,7 +120,8 @@ def get_swept_closest_pt_batch_env(
|
||||
dis_length = float(0.0)
|
||||
jump_distance = float(0.0)
|
||||
mid_distance = float(0.0)
|
||||
env_idx = env_query_idx[b_idx]
|
||||
if use_batch_env:
|
||||
env_idx = env_query_idx[b_idx]
|
||||
i = max_nmesh * env_idx
|
||||
n_mesh = i + n_env_mesh[env_idx]
|
||||
obj_position = wp.vec3()
|
||||
@@ -423,26 +148,33 @@ def get_swept_closest_pt_batch_env(
|
||||
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - local_pt)
|
||||
delta = cl_pt - local_pt
|
||||
dis_length = wp.length(delta)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - local_pt) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist == in_rad:
|
||||
cl_pt = sign * (delta) / (dist)
|
||||
else:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
euclidean_distance = dist
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
cl_pt = (1.0 / eta) * dist * cl_pt
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
else:
|
||||
dist = -1.0 * dist
|
||||
euclidean_distance = dist
|
||||
else:
|
||||
dist = max_dist_buffer
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
euclidean_distance = dist
|
||||
dist = max(euclidean_distance - in_rad, in_rad)
|
||||
|
||||
mid_distance = dist
|
||||
mid_distance = euclidean_distance
|
||||
# transform sphere -1
|
||||
if h_idx > 0 and mid_distance < sphere_0_distance:
|
||||
jump_distance = mid_distance
|
||||
@@ -457,24 +189,31 @@ def get_swept_closest_pt_batch_env(
|
||||
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - sphere_int)
|
||||
delta = cl_pt - sphere_int
|
||||
dis_length = wp.length(delta)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - sphere_int) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist == in_rad:
|
||||
cl_pt = sign * (delta) / (dist)
|
||||
else:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
euclidean_distance = dist
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
cl_pt = (1.0 / eta) * dist * cl_pt
|
||||
|
||||
closest_distance += dist_metric
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
|
||||
closest_point += grad_vec
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
jump_distance += dist
|
||||
dist = max(euclidean_distance - in_rad, in_rad)
|
||||
jump_distance += euclidean_distance
|
||||
else:
|
||||
dist = max(-dist - in_rad, in_rad)
|
||||
jump_distance += dist
|
||||
euclidean_distance = dist
|
||||
jump_distance += euclidean_distance
|
||||
else:
|
||||
jump_distance += max_dist_buffer
|
||||
j += 1
|
||||
@@ -495,24 +234,30 @@ def get_swept_closest_pt_batch_env(
|
||||
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - sphere_int)
|
||||
delta = cl_pt - sphere_int
|
||||
dis_length = wp.length(delta)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - sphere_int) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
euclidean_distance = dist
|
||||
if dist == in_rad:
|
||||
cl_pt = sign * (delta) / (dist)
|
||||
else:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
# cl_pt = sign * (delta) / dis_length
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
cl_pt = (1.0 / eta) * dist * cl_pt
|
||||
closest_distance += dist_metric
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
|
||||
closest_point += grad_vec
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
dist = max(euclidean_distance - in_rad, in_rad)
|
||||
jump_distance += dist
|
||||
|
||||
else:
|
||||
dist = max(-dist - in_rad, in_rad)
|
||||
|
||||
jump_distance += dist
|
||||
else:
|
||||
jump_distance += max_dist_buffer
|
||||
@@ -542,179 +287,54 @@ def get_swept_closest_pt_batch_env(
|
||||
|
||||
# use central difference
|
||||
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
|
||||
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
|
||||
# norm_vel_vec = -1.0 * norm_vel_vec
|
||||
# sph_acc_vec = -1.0 * sph_acc_vec
|
||||
sph_vel = wp.length(norm_vel_vec)
|
||||
if sph_vel > 1e-3:
|
||||
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
|
||||
|
||||
norm_vel_vec = norm_vel_vec / sph_vel
|
||||
norm_vel_vec = norm_vel_vec * (1.0 / sph_vel)
|
||||
|
||||
orth_proj = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) - wp.outer(
|
||||
norm_vel_vec, norm_vel_vec
|
||||
)
|
||||
curvature_vec = sph_acc_vec / (sph_vel * sph_vel)
|
||||
|
||||
curvature_vec = orth_proj * (sph_acc_vec / (sph_vel * sph_vel))
|
||||
orth_proj = wp.mat33(0.0)
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
orth_proj[i, j] = -1.0 * norm_vel_vec[i] * norm_vel_vec[j]
|
||||
|
||||
closest_point = sph_vel * ((orth_proj * closest_point) - closest_distance * curvature_vec)
|
||||
orth_proj[0, 0] = orth_proj[0, 0] + 1.0
|
||||
orth_proj[1, 1] = orth_proj[1, 1] + 1.0
|
||||
orth_proj[2, 2] = orth_proj[2, 2] + 1.0
|
||||
|
||||
closest_distance = sph_vel * closest_distance
|
||||
orth_curv = wp.vec3(
|
||||
0.0, 0.0, 0.0
|
||||
) # closest_distance * (orth_proj @ curvature_vec) #wp.matmul(orth_proj, curvature_vec)
|
||||
orth_pt = wp.vec3(0.0, 0.0, 0.0) # orth_proj @ closest_point
|
||||
for i in range(3):
|
||||
orth_pt[i] = (
|
||||
orth_proj[i, 0] * closest_point[0]
|
||||
+ orth_proj[i, 1] * closest_point[1]
|
||||
+ orth_proj[i, 2] * closest_point[2]
|
||||
)
|
||||
|
||||
orth_curv[i] = closest_distance * (
|
||||
orth_proj[i, 0] * curvature_vec[0]
|
||||
+ orth_proj[i, 1] * curvature_vec[1]
|
||||
+ orth_proj[i, 2] * curvature_vec[2]
|
||||
)
|
||||
|
||||
closest_point = sph_vel * (orth_pt - orth_curv)
|
||||
|
||||
closest_distance = sph_vel * closest_distance
|
||||
|
||||
distance[tid] = weight[0] * closest_distance
|
||||
sparsity_idx[tid] = uint_one
|
||||
if write_grad == 1:
|
||||
# compute gradient:
|
||||
if closest_distance > 0.0:
|
||||
closest_distance = weight[0]
|
||||
closest_distance = weight[0]
|
||||
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
|
||||
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
|
||||
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
|
||||
|
||||
|
||||
@wp.kernel
|
||||
def get_closest_pt(
|
||||
pt: wp.array(dtype=wp.vec4),
|
||||
distance: wp.array(dtype=wp.float32), # this stores the output cost
|
||||
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
|
||||
sparsity_idx: wp.array(dtype=wp.uint8),
|
||||
weight: wp.array(dtype=wp.float32),
|
||||
activation_distance: wp.array(dtype=wp.float32), # eta threshold
|
||||
mesh: wp.array(dtype=wp.uint64),
|
||||
mesh_pose: wp.array(dtype=wp.float32),
|
||||
mesh_enable: wp.array(dtype=wp.uint8),
|
||||
n_env_mesh: wp.array(dtype=wp.int32),
|
||||
max_dist: wp.float32,
|
||||
write_grad: wp.uint8,
|
||||
batch_size: wp.int32,
|
||||
horizon: wp.int32,
|
||||
nspheres: wp.int32,
|
||||
max_nmesh: wp.int32,
|
||||
):
|
||||
# we launch nspheres kernels
|
||||
# compute gradient here and return
|
||||
# distance is negative outside and positive inside
|
||||
tid = wp.tid()
|
||||
n_mesh = int(0)
|
||||
b_idx = int(0)
|
||||
h_idx = int(0)
|
||||
sph_idx = int(0)
|
||||
|
||||
# env_idx = int(0)
|
||||
|
||||
b_idx = tid / (horizon * nspheres)
|
||||
|
||||
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
|
||||
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
|
||||
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
|
||||
return
|
||||
|
||||
face_index = int(0)
|
||||
face_u = float(0.0)
|
||||
face_v = float(0.0)
|
||||
sign = float(0.0)
|
||||
dist = float(0.0)
|
||||
grad_vec = wp.vec3(0.0)
|
||||
eta = float(0.05)
|
||||
dist_metric = float(0.0)
|
||||
|
||||
cl_pt = wp.vec3(0.0)
|
||||
local_pt = wp.vec3(0.0)
|
||||
in_sphere = pt[tid]
|
||||
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
|
||||
in_rad = in_sphere[3]
|
||||
uint_zero = wp.uint8(0)
|
||||
uint_one = wp.uint8(1)
|
||||
if in_rad < 0.0:
|
||||
distance[tid] = 0.0
|
||||
if write_grad == 1 and sparsity_idx[tid] == uint_one:
|
||||
sparsity_idx[tid] = uint_zero
|
||||
closest_pt[tid * 4] = 0.0
|
||||
closest_pt[tid * 4 + 1] = 0.0
|
||||
closest_pt[tid * 4 + 2] = 0.0
|
||||
return
|
||||
|
||||
eta = activation_distance[0]
|
||||
in_rad += eta
|
||||
max_dist_buffer = float(0.0)
|
||||
max_dist_buffer = max_dist
|
||||
if in_rad > max_dist_buffer:
|
||||
max_dist_buffer += in_rad
|
||||
|
||||
# TODO: read vec4 and use first 3 for sphere position and last one for radius
|
||||
# in_pt = pt[tid]
|
||||
closest_distance = float(0.0)
|
||||
closest_point = wp.vec3(0.0)
|
||||
i = int(0)
|
||||
dis_length = float(0.0)
|
||||
# read env index:
|
||||
# env_idx = env_query_idx[b_idx]
|
||||
# read number of boxes in current environment:
|
||||
|
||||
# get start index
|
||||
i = int(0)
|
||||
n_mesh = n_env_mesh[0]
|
||||
obj_position = wp.vec3()
|
||||
|
||||
# mesh_idx = wp.uint64(0)
|
||||
while i < n_mesh:
|
||||
if mesh_enable[i] == uint_one:
|
||||
# transform point to mesh frame:
|
||||
# mesh_pt = T_inverse @ w_pt
|
||||
|
||||
# read object pose:
|
||||
obj_position[0] = mesh_pose[i * 8 + 0]
|
||||
obj_position[1] = mesh_pose[i * 8 + 1]
|
||||
obj_position[2] = mesh_pose[i * 8 + 2]
|
||||
obj_quat = wp.quaternion(
|
||||
mesh_pose[i * 8 + 4],
|
||||
mesh_pose[i * 8 + 5],
|
||||
mesh_pose[i * 8 + 6],
|
||||
mesh_pose[i * 8 + 3],
|
||||
)
|
||||
|
||||
obj_w_pose = wp.transform(obj_position, obj_quat)
|
||||
|
||||
local_pt = wp.transform_point(obj_w_pose, in_pt)
|
||||
# mesh_idx = mesh[i]
|
||||
if wp.mesh_query_point(
|
||||
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - local_pt)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - local_pt) / dis_length
|
||||
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
|
||||
i += 1
|
||||
|
||||
if closest_distance == 0:
|
||||
if sparsity_idx[tid] == uint_zero:
|
||||
return
|
||||
sparsity_idx[tid] = uint_zero
|
||||
distance[tid] = 0.0
|
||||
if write_grad == 1:
|
||||
closest_pt[tid * 4 + 0] = 0.0
|
||||
closest_pt[tid * 4 + 1] = 0.0
|
||||
closest_pt[tid * 4 + 2] = 0.0
|
||||
else:
|
||||
distance[tid] = weight[0] * closest_distance
|
||||
sparsity_idx[tid] = uint_one
|
||||
if write_grad == 1:
|
||||
# compute gradient:
|
||||
if closest_distance > 0.0:
|
||||
closest_distance = weight[0]
|
||||
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
|
||||
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
|
||||
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
|
||||
|
||||
|
||||
@wp.kernel
|
||||
def get_closest_pt_batch_env(
|
||||
pt: wp.array(dtype=wp.vec4),
|
||||
@@ -727,13 +347,15 @@ def get_closest_pt_batch_env(
|
||||
mesh_pose: wp.array(dtype=wp.float32),
|
||||
mesh_enable: wp.array(dtype=wp.uint8),
|
||||
n_env_mesh: wp.array(dtype=wp.int32),
|
||||
max_dist: wp.float32,
|
||||
max_dist: wp.array(dtype=wp.float32),
|
||||
write_grad: wp.uint8,
|
||||
batch_size: wp.int32,
|
||||
horizon: wp.int32,
|
||||
nspheres: wp.int32,
|
||||
max_nmesh: wp.int32,
|
||||
env_query_idx: wp.array(dtype=wp.int32),
|
||||
use_batch_env: wp.uint8,
|
||||
compute_esdf: wp.uint8,
|
||||
):
|
||||
# we launch nspheres kernels
|
||||
# compute gradient here and return
|
||||
@@ -779,8 +401,9 @@ def get_closest_pt_batch_env(
|
||||
|
||||
return
|
||||
eta = activation_distance[0]
|
||||
in_rad += eta
|
||||
max_dist_buffer = max_dist
|
||||
if compute_esdf != 1:
|
||||
in_rad += eta
|
||||
max_dist_buffer = max_dist[0]
|
||||
if (in_rad) > max_dist_buffer:
|
||||
max_dist_buffer += in_rad
|
||||
|
||||
@@ -791,7 +414,9 @@ def get_closest_pt_batch_env(
|
||||
dis_length = float(0.0)
|
||||
|
||||
# read env index:
|
||||
env_idx = env_query_idx[b_idx]
|
||||
if use_batch_env:
|
||||
env_idx = env_query_idx[b_idx]
|
||||
|
||||
# read number of boxes in current environment:
|
||||
|
||||
# get start index
|
||||
@@ -799,7 +424,9 @@ def get_closest_pt_batch_env(
|
||||
i = max_nmesh * env_idx
|
||||
n_mesh = i + n_env_mesh[env_idx]
|
||||
obj_position = wp.vec3()
|
||||
|
||||
max_dist_value = -1.0 * max_dist_buffer
|
||||
if compute_esdf == 1:
|
||||
closest_distance = max_dist_value
|
||||
while i < n_mesh:
|
||||
if mesh_enable[i] == uint_one:
|
||||
# transform point to mesh frame:
|
||||
@@ -822,21 +449,39 @@ def get_closest_pt_batch_env(
|
||||
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - local_pt)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - local_pt) / dis_length
|
||||
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
delta = cl_pt - local_pt
|
||||
dis_length = wp.length(delta)
|
||||
dist = -1.0 * dis_length * sign
|
||||
|
||||
if compute_esdf == 1:
|
||||
if dist > max_dist_value:
|
||||
max_dist_value = dist
|
||||
closest_distance = dist
|
||||
if write_grad == 1:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
|
||||
closest_point = grad_vec
|
||||
|
||||
else:
|
||||
dist = dist + in_rad
|
||||
|
||||
if dist > 0:
|
||||
if dist == in_rad:
|
||||
cl_pt = sign * (delta) / (dist)
|
||||
else:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
cl_pt = (1.0 / eta) * dist * cl_pt
|
||||
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
i += 1
|
||||
|
||||
if closest_distance == 0:
|
||||
if closest_distance == 0 and compute_esdf != 1:
|
||||
if sparsity_idx[tid] == uint_zero:
|
||||
return
|
||||
sparsity_idx[tid] = uint_zero
|
||||
@@ -850,8 +495,7 @@ def get_closest_pt_batch_env(
|
||||
sparsity_idx[tid] = uint_one
|
||||
if write_grad == 1:
|
||||
# compute gradient:
|
||||
if closest_distance > 0.0:
|
||||
closest_distance = weight[0]
|
||||
closest_distance = weight[0]
|
||||
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
|
||||
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
|
||||
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
|
||||
@@ -871,62 +515,43 @@ class SdfMeshWarpPy(torch.autograd.Function):
|
||||
mesh_pose_inverse,
|
||||
mesh_enable,
|
||||
n_env_mesh,
|
||||
max_dist=0.05,
|
||||
max_dist,
|
||||
env_query_idx=None,
|
||||
return_loss=False,
|
||||
compute_esdf=False,
|
||||
):
|
||||
b, h, n, _ = query_spheres.shape
|
||||
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
# launch
|
||||
wp.launch(
|
||||
kernel=get_closest_pt,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
max_dist,
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
else:
|
||||
wp.launch(
|
||||
kernel=get_closest_pt_batch_env,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
max_dist,
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
use_batch_env = False
|
||||
env_query_idx = n_env_mesh
|
||||
|
||||
wp.launch(
|
||||
kernel=get_closest_pt_batch_env,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
wp.from_torch(max_dist, dtype=wp.float32),
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
|
||||
use_batch_env,
|
||||
compute_esdf,
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
ctx.return_loss = return_loss
|
||||
ctx.save_for_backward(out_grad)
|
||||
return out_cost
|
||||
@@ -939,7 +564,22 @@ class SdfMeshWarpPy(torch.autograd.Function):
|
||||
grad_sph = r
|
||||
if ctx.return_loss:
|
||||
grad_sph = r * grad_output.unsqueeze(-1)
|
||||
return grad_sph, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
return (
|
||||
grad_sph,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class SweptSdfMeshWarpPy(torch.autograd.Function):
|
||||
@@ -957,69 +597,46 @@ class SweptSdfMeshWarpPy(torch.autograd.Function):
|
||||
mesh_pose_inverse,
|
||||
mesh_enable,
|
||||
n_env_mesh,
|
||||
max_dist,
|
||||
sweep_steps=1,
|
||||
enable_speed_metric=False,
|
||||
max_dist=0.05,
|
||||
env_query_idx=None,
|
||||
return_loss=False,
|
||||
):
|
||||
b, h, n, _ = query_spheres.shape
|
||||
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
wp.launch(
|
||||
kernel=get_swept_closest_pt,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(speed_dt),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
max_dist,
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
else:
|
||||
wp.launch(
|
||||
kernel=get_swept_closest_pt_batch_env,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(speed_dt),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
max_dist,
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
use_batch_env = False
|
||||
env_query_idx = n_env_mesh
|
||||
|
||||
wp.launch(
|
||||
kernel=get_swept_closest_pt_batch_env,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(speed_dt),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
wp.from_torch(max_dist, dtype=wp.float32),
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
|
||||
use_batch_env,
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
ctx.return_loss = return_loss
|
||||
ctx.save_for_backward(out_grad)
|
||||
return out_cost
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.curobolib.geom import SdfSphereOBB, SdfSweptSphereOBB
|
||||
from curobo.geom.types import Cuboid, Mesh, Obstacle, WorldConfig, batch_tensor_cube
|
||||
from curobo.geom.types import Cuboid, Mesh, Obstacle, VoxelGrid, WorldConfig, batch_tensor_cube
|
||||
from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.math import Pose
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
@@ -39,10 +39,14 @@ class CollisionBuffer:
|
||||
def initialize_from_shape(cls, shape: torch.Size, tensor_args: TensorDeviceType):
|
||||
batch, horizon, n_spheres, _ = shape
|
||||
distance_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
(batch, horizon, n_spheres),
|
||||
device=tensor_args.device,
|
||||
dtype=tensor_args.collision_distance_dtype,
|
||||
)
|
||||
grad_distance_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres, 4), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
(batch, horizon, n_spheres, 4),
|
||||
device=tensor_args.device,
|
||||
dtype=tensor_args.collision_gradient_dtype,
|
||||
)
|
||||
sparsity_idx = torch.zeros(
|
||||
(batch, horizon, n_spheres),
|
||||
@@ -54,10 +58,14 @@ class CollisionBuffer:
|
||||
def _update_from_shape(self, shape: torch.Size, tensor_args: TensorDeviceType):
|
||||
batch, horizon, n_spheres, _ = shape
|
||||
self.distance_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
(batch, horizon, n_spheres),
|
||||
device=tensor_args.device,
|
||||
dtype=tensor_args.collision_distance_dtype,
|
||||
)
|
||||
self.grad_distance_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres, 4), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
(batch, horizon, n_spheres, 4),
|
||||
device=tensor_args.device,
|
||||
dtype=tensor_args.collision_gradient_dtype,
|
||||
)
|
||||
self.sparsity_index_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres),
|
||||
@@ -100,6 +108,7 @@ class CollisionQueryBuffer:
|
||||
primitive_collision_buffer: Optional[CollisionBuffer] = None
|
||||
mesh_collision_buffer: Optional[CollisionBuffer] = None
|
||||
blox_collision_buffer: Optional[CollisionBuffer] = None
|
||||
voxel_collision_buffer: Optional[CollisionBuffer] = None
|
||||
shape: Optional[torch.Size] = None
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -110,6 +119,8 @@ class CollisionQueryBuffer:
|
||||
self.shape = self.mesh_collision_buffer.shape
|
||||
elif self.blox_collision_buffer is not None:
|
||||
self.shape = self.blox_collision_buffer.shape
|
||||
elif self.voxel_collision_buffer is not None:
|
||||
self.shape = self.voxel_collision_buffer.shape
|
||||
|
||||
def __mul__(self, scalar: float):
|
||||
if self.primitive_collision_buffer is not None:
|
||||
@@ -118,17 +129,27 @@ class CollisionQueryBuffer:
|
||||
self.mesh_collision_buffer = self.mesh_collision_buffer * scalar
|
||||
if self.blox_collision_buffer is not None:
|
||||
self.blox_collision_buffer = self.blox_collision_buffer * scalar
|
||||
if self.voxel_collision_buffer is not None:
|
||||
self.voxel_collision_buffer = self.voxel_collision_buffer * scalar
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
prim_buffer = mesh_buffer = blox_buffer = None
|
||||
prim_buffer = mesh_buffer = blox_buffer = voxel_buffer = None
|
||||
if self.primitive_collision_buffer is not None:
|
||||
prim_buffer = self.primitive_collision_buffer.clone()
|
||||
if self.mesh_collision_buffer is not None:
|
||||
mesh_buffer = self.mesh_collision_buffer.clone()
|
||||
if self.blox_collision_buffer is not None:
|
||||
blox_buffer = self.blox_collision_buffer.clone()
|
||||
return CollisionQueryBuffer(prim_buffer, mesh_buffer, blox_buffer, self.shape)
|
||||
if self.voxel_collision_buffer is not None:
|
||||
voxel_buffer = self.voxel_collision_buffer.clone()
|
||||
return CollisionQueryBuffer(
|
||||
prim_buffer,
|
||||
mesh_buffer,
|
||||
blox_buffer,
|
||||
voxel_collision_buffer=voxel_buffer,
|
||||
shape=self.shape,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def initialize_from_shape(
|
||||
@@ -137,14 +158,18 @@ class CollisionQueryBuffer:
|
||||
tensor_args: TensorDeviceType,
|
||||
collision_types: Dict[str, bool],
|
||||
):
|
||||
primitive_buffer = mesh_buffer = blox_buffer = None
|
||||
primitive_buffer = mesh_buffer = blox_buffer = voxel_buffer = None
|
||||
if "primitive" in collision_types and collision_types["primitive"]:
|
||||
primitive_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
if "mesh" in collision_types and collision_types["mesh"]:
|
||||
mesh_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
if "blox" in collision_types and collision_types["blox"]:
|
||||
blox_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
return CollisionQueryBuffer(primitive_buffer, mesh_buffer, blox_buffer)
|
||||
if "voxel" in collision_types and collision_types["voxel"]:
|
||||
voxel_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
return CollisionQueryBuffer(
|
||||
primitive_buffer, mesh_buffer, blox_buffer, voxel_collision_buffer=voxel_buffer
|
||||
)
|
||||
|
||||
def create_from_shape(
|
||||
self,
|
||||
@@ -160,8 +185,9 @@ class CollisionQueryBuffer:
|
||||
self.mesh_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
if "blox" in collision_types and collision_types["blox"]:
|
||||
self.blox_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
if "voxel" in collision_types and collision_types["voxel"]:
|
||||
self.voxel_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
self.shape = shape
|
||||
# return self
|
||||
|
||||
def update_buffer_shape(
|
||||
self,
|
||||
@@ -169,12 +195,10 @@ class CollisionQueryBuffer:
|
||||
tensor_args: TensorDeviceType,
|
||||
collision_types: Optional[Dict[str, bool]],
|
||||
):
|
||||
# print(shape, self.shape)
|
||||
# update buffers:
|
||||
assert len(shape) == 4 # shape is: batch, horizon, n_spheres, 4
|
||||
if self.shape is None: # buffers not initialized:
|
||||
self.create_from_shape(shape, tensor_args, collision_types)
|
||||
# print("Creating new memory", self.shape)
|
||||
else:
|
||||
# update buffers if shape doesn't match:
|
||||
# TODO: allow for dynamic change of collision_types
|
||||
@@ -185,6 +209,8 @@ class CollisionQueryBuffer:
|
||||
self.mesh_collision_buffer.update_buffer_shape(shape, tensor_args)
|
||||
if self.blox_collision_buffer is not None:
|
||||
self.blox_collision_buffer.update_buffer_shape(shape, tensor_args)
|
||||
if self.voxel_collision_buffer is not None:
|
||||
self.voxel_collision_buffer.update_buffer_shape(shape, tensor_args)
|
||||
self.shape = shape
|
||||
|
||||
def get_gradient_buffer(
|
||||
@@ -208,6 +234,12 @@ class CollisionQueryBuffer:
|
||||
current_buffer = blox_buffer.clone()
|
||||
else:
|
||||
current_buffer += blox_buffer
|
||||
if self.voxel_collision_buffer is not None:
|
||||
voxel_buffer = self.voxel_collision_buffer.grad_distance_buffer
|
||||
if current_buffer is None:
|
||||
current_buffer = voxel_buffer.clone()
|
||||
else:
|
||||
current_buffer += voxel_buffer
|
||||
|
||||
return current_buffer
|
||||
|
||||
@@ -221,6 +253,7 @@ class CollisionCheckerType(Enum):
|
||||
PRIMITIVE = "PRIMITIVE"
|
||||
BLOX = "BLOX"
|
||||
MESH = "MESH"
|
||||
VOXEL = "VOXEL"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -230,11 +263,13 @@ class WorldCollisionConfig:
|
||||
cache: Optional[Dict[Obstacle, int]] = None
|
||||
n_envs: int = 1
|
||||
checker_type: CollisionCheckerType = CollisionCheckerType.PRIMITIVE
|
||||
max_distance: float = 0.01
|
||||
max_distance: Union[torch.Tensor, float] = 0.01
|
||||
|
||||
def __post_init__(self):
|
||||
if self.world_model is not None and isinstance(self.world_model, list):
|
||||
self.n_envs = len(self.world_model)
|
||||
if isinstance(self.max_distance, float):
|
||||
self.max_distance = self.tensor_args.to_device([self.max_distance])
|
||||
|
||||
@staticmethod
|
||||
def load_from_dict(
|
||||
@@ -261,6 +296,8 @@ class WorldCollision(WorldCollisionConfig):
|
||||
if config is not None:
|
||||
WorldCollisionConfig.__init__(self, **vars(config))
|
||||
self.collision_types = {} # Use this dictionary to store collision types
|
||||
self._cache_voxelization = None
|
||||
self._cache_voxelization_collision_buffer = None
|
||||
|
||||
def load_collision_model(self, world_model: WorldConfig):
|
||||
raise NotImplementedError
|
||||
@@ -273,6 +310,8 @@ class WorldCollision(WorldCollisionConfig):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
"""
|
||||
Computes the signed distance via analytic function
|
||||
@@ -310,6 +349,7 @@ class WorldCollision(WorldCollisionConfig):
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -338,6 +378,118 @@ class WorldCollision(WorldCollisionConfig):
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_voxels_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
|
||||
voxel_size: float = 0.02,
|
||||
) -> Union[List[Cuboid], torch.Tensor]:
|
||||
new_grid = self.get_occupancy_in_bounding_box(cuboid, voxel_size)
|
||||
occupied = new_grid.get_occupied_voxels(0.0)
|
||||
return occupied
|
||||
|
||||
def clear_voxelization_cache(self):
|
||||
self._cache_voxelization = None
|
||||
|
||||
def update_cache_voxelization(self, new_grid: VoxelGrid):
|
||||
if (
|
||||
self._cache_voxelization is None
|
||||
or self._cache_voxelization.voxel_size != new_grid.voxel_size
|
||||
or self._cache_voxelization.dims != new_grid.dims
|
||||
):
|
||||
self._cache_voxelization = new_grid
|
||||
self._cache_voxelization.xyzr_tensor = self._cache_voxelization.create_xyzr_tensor(
|
||||
transform_to_origin=True, tensor_args=self.tensor_args
|
||||
)
|
||||
self._cache_voxelization_collision_buffer = CollisionQueryBuffer()
|
||||
xyzr = self._cache_voxelization.xyzr_tensor.view(-1, 1, 1, 4)
|
||||
|
||||
self._cache_voxelization_collision_buffer.update_buffer_shape(
|
||||
xyzr.shape,
|
||||
self.tensor_args,
|
||||
self.collision_types,
|
||||
)
|
||||
|
||||
def get_occupancy_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
|
||||
voxel_size: float = 0.02,
|
||||
) -> VoxelGrid:
|
||||
new_grid = VoxelGrid(
|
||||
name=cuboid.name, dims=cuboid.dims, pose=cuboid.pose, voxel_size=voxel_size
|
||||
)
|
||||
|
||||
self.update_cache_voxelization(new_grid)
|
||||
|
||||
xyzr = self._cache_voxelization.xyzr_tensor
|
||||
|
||||
xyzr = xyzr.view(-1, 1, 1, 4)
|
||||
|
||||
weight = self.tensor_args.to_device([1.0])
|
||||
act_distance = self.tensor_args.to_device([0.0])
|
||||
|
||||
d_sph = self.get_sphere_collision(
|
||||
xyzr,
|
||||
self._cache_voxelization_collision_buffer,
|
||||
weight,
|
||||
act_distance,
|
||||
)
|
||||
d_sph = d_sph.reshape(-1)
|
||||
new_grid.xyzr_tensor = self._cache_voxelization.xyzr_tensor
|
||||
new_grid.feature_tensor = d_sph
|
||||
return new_grid
|
||||
|
||||
def get_esdf_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
|
||||
voxel_size: float = 0.02,
|
||||
dtype=torch.float32,
|
||||
) -> VoxelGrid:
|
||||
new_grid = VoxelGrid(
|
||||
name=cuboid.name,
|
||||
dims=cuboid.dims,
|
||||
pose=cuboid.pose,
|
||||
voxel_size=voxel_size,
|
||||
feature_dtype=dtype,
|
||||
)
|
||||
|
||||
self.update_cache_voxelization(new_grid)
|
||||
|
||||
xyzr = self._cache_voxelization.xyzr_tensor
|
||||
voxel_shape = xyzr.shape
|
||||
xyzr = xyzr.view(-1, 1, 1, 4)
|
||||
|
||||
weight = self.tensor_args.to_device([1.0])
|
||||
|
||||
d_sph = self.get_sphere_distance(
|
||||
xyzr,
|
||||
self._cache_voxelization_collision_buffer,
|
||||
weight,
|
||||
self.max_distance,
|
||||
sum_collisions=False,
|
||||
compute_esdf=True,
|
||||
)
|
||||
|
||||
d_sph = d_sph.reshape(-1)
|
||||
voxel_grid = self._cache_voxelization
|
||||
voxel_grid.feature_tensor = d_sph
|
||||
|
||||
return voxel_grid
|
||||
|
||||
def get_mesh_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
|
||||
voxel_size: float = 0.02,
|
||||
) -> Mesh:
|
||||
voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size)
|
||||
# voxels = voxels.cpu().numpy()
|
||||
# cuboids = [Cuboid(name="c_"+str(x), pose=[voxels[x,0],voxels[x,1],voxels[x,2], 1,0,0,0], dims=[voxel_size, voxel_size, voxel_size]) for x in range(voxels.shape[0])]
|
||||
# mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0]
|
||||
mesh = Mesh.from_pointcloud(
|
||||
voxels[:, :3].detach().cpu().numpy(),
|
||||
pitch=voxel_size * 1.1,
|
||||
)
|
||||
return mesh
|
||||
|
||||
|
||||
class WorldPrimitiveCollision(WorldCollision):
|
||||
"""World Oriented Bounding Box representation object
|
||||
@@ -354,6 +506,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
self._env_n_obbs = None
|
||||
self._env_obbs_names = None
|
||||
self._init_cache()
|
||||
|
||||
if self.world_model is not None:
|
||||
if isinstance(self.world_model, list):
|
||||
self.load_batch_collision_model(self.world_model)
|
||||
@@ -656,6 +809,8 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
|
||||
raise ValueError("Primitive Collision has no obstacles")
|
||||
@@ -673,6 +828,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
collision_query_buffer.primitive_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self._cube_tensor_list[0],
|
||||
self._cube_tensor_list[0],
|
||||
self._cube_tensor_list[1],
|
||||
@@ -687,6 +843,8 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
sum_collisions,
|
||||
compute_esdf,
|
||||
)
|
||||
|
||||
return dist
|
||||
@@ -699,6 +857,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
**kwargs,
|
||||
):
|
||||
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
|
||||
raise ValueError("Primitive Collision has no obstacles")
|
||||
@@ -717,6 +876,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
collision_query_buffer.primitive_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self._cube_tensor_list[0],
|
||||
self._cube_tensor_list[0],
|
||||
self._cube_tensor_list[1],
|
||||
@@ -728,7 +888,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
h,
|
||||
n,
|
||||
query_sphere.requires_grad,
|
||||
False,
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
)
|
||||
@@ -745,6 +905,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
"""
|
||||
Computes the signed distance via analytic function
|
||||
@@ -784,6 +945,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
sum_collisions,
|
||||
)
|
||||
|
||||
return dist
|
||||
@@ -836,7 +998,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
query_sphere.requires_grad,
|
||||
False,
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
)
|
||||
@@ -845,70 +1007,5 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
|
||||
def clear_cache(self):
|
||||
if self._cube_tensor_list is not None:
|
||||
self._cube_tensor_list[2][:] = 1
|
||||
self._cube_tensor_list[2][:] = 0
|
||||
self._env_n_obbs[:] = 0
|
||||
|
||||
def get_voxels_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid,
|
||||
voxel_size: float = 0.02,
|
||||
) -> Union[List[Cuboid], torch.Tensor]:
|
||||
bounds = cuboid.dims
|
||||
low = [-bounds[0], -bounds[1], -bounds[2]]
|
||||
high = [bounds[0], bounds[1], bounds[2]]
|
||||
trange = [h - l for l, h in zip(low, high)]
|
||||
|
||||
x = torch.linspace(
|
||||
-bounds[0], bounds[0], int(trange[0] // voxel_size) + 1, device=self.tensor_args.device
|
||||
)
|
||||
y = torch.linspace(
|
||||
-bounds[1], bounds[1], int(trange[1] // voxel_size) + 1, device=self.tensor_args.device
|
||||
)
|
||||
z = torch.linspace(
|
||||
-bounds[2], bounds[2], int(trange[2] // voxel_size) + 1, device=self.tensor_args.device
|
||||
)
|
||||
w, l, h = x.shape[0], y.shape[0], z.shape[0]
|
||||
xyz = (
|
||||
torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
|
||||
)
|
||||
|
||||
pose = Pose.from_list(cuboid.pose, tensor_args=self.tensor_args)
|
||||
xyz = pose.transform_points(xyz.contiguous())
|
||||
r = torch.zeros_like(xyz[:, 0:1]) + voxel_size
|
||||
xyzr = torch.cat([xyz, r], dim=1)
|
||||
xyzr = xyzr.reshape(-1, 1, 1, 4)
|
||||
collision_buffer = CollisionQueryBuffer()
|
||||
collision_buffer.update_buffer_shape(
|
||||
xyzr.shape,
|
||||
self.tensor_args,
|
||||
self.collision_types,
|
||||
)
|
||||
weight = self.tensor_args.to_device([1.0])
|
||||
act_distance = self.tensor_args.to_device([0.0])
|
||||
|
||||
d_sph = self.get_sphere_collision(
|
||||
xyzr,
|
||||
collision_buffer,
|
||||
weight,
|
||||
act_distance,
|
||||
)
|
||||
d_sph = d_sph.reshape(-1)
|
||||
xyzr = xyzr.reshape(-1, 4)
|
||||
# get occupied voxels:
|
||||
occupied = xyzr[d_sph > 0.0]
|
||||
return occupied
|
||||
|
||||
def get_mesh_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid,
|
||||
voxel_size: float = 0.02,
|
||||
) -> Mesh:
|
||||
voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size)
|
||||
# voxels = voxels.cpu().numpy()
|
||||
# cuboids = [Cuboid(name="c_"+str(x), pose=[voxels[x,0],voxels[x,1],voxels[x,2], 1,0,0,0], dims=[voxel_size, voxel_size, voxel_size]) for x in range(voxels.shape[0])]
|
||||
# mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0]
|
||||
mesh = Mesh.from_pointcloud(
|
||||
voxels[:, :3].detach().cpu().numpy(),
|
||||
pitch=voxel_size * 1.1,
|
||||
)
|
||||
return mesh
|
||||
|
||||
@@ -176,6 +176,8 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
if "blox" not in self.collision_types or not self.collision_types["blox"]:
|
||||
return super().get_sphere_distance(
|
||||
@@ -185,6 +187,8 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
activation_distance,
|
||||
env_query_idx,
|
||||
return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
|
||||
d = self._get_blox_sdf(
|
||||
@@ -205,8 +209,13 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
activation_distance,
|
||||
env_query_idx,
|
||||
return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
d = d + d_base
|
||||
if compute_esdf:
|
||||
d = torch.maximum(d, d_base)
|
||||
else:
|
||||
d = d + d_base
|
||||
|
||||
return d
|
||||
|
||||
@@ -262,6 +271,7 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
if "blox" not in self.collision_types or not self.collision_types["blox"]:
|
||||
return super().get_swept_sphere_distance(
|
||||
@@ -274,6 +284,7 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
enable_speed_metric,
|
||||
env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
|
||||
d = self._get_blox_swept_sdf(
|
||||
@@ -301,6 +312,7 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
enable_speed_metric,
|
||||
env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
d = d + d_base
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
self._mesh_tensor_list[0][env_idx, :max_nmesh] = w_mid
|
||||
self._mesh_tensor_list[1][env_idx, :max_nmesh, :7] = w_inv_pose
|
||||
self._mesh_tensor_list[2][env_idx, :max_nmesh] = 1
|
||||
self._mesh_tensor_list[2][env_idx, max_nmesh:] = 0
|
||||
|
||||
self._env_mesh_names[env_idx][:max_nmesh] = name_list
|
||||
self._env_n_mesh[env_idx] = max_nmesh
|
||||
@@ -355,6 +356,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx=None,
|
||||
return_loss=False,
|
||||
compute_esdf=False,
|
||||
):
|
||||
d = SdfMeshWarpPy.apply(
|
||||
query_spheres,
|
||||
@@ -370,6 +372,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
self.max_distance,
|
||||
env_query_idx,
|
||||
return_loss,
|
||||
compute_esdf,
|
||||
)
|
||||
return d
|
||||
|
||||
@@ -397,9 +400,9 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
self._mesh_tensor_list[1],
|
||||
self._mesh_tensor_list[2],
|
||||
self._env_n_mesh,
|
||||
self.max_distance,
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
self.max_distance,
|
||||
env_query_idx,
|
||||
return_loss,
|
||||
)
|
||||
@@ -413,6 +416,8 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
# TODO: if no mesh object exist, call primitive
|
||||
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
|
||||
@@ -423,6 +428,8 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
|
||||
d = self._get_sdf(
|
||||
@@ -432,6 +439,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
|
||||
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
|
||||
@@ -443,8 +451,13 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
d_val = d.view(d_prim.shape) + d_prim
|
||||
if compute_esdf:
|
||||
d_val = torch.maximum(d.view(d_prim.shape), d_prim)
|
||||
else:
|
||||
d_val = d.view(d_prim.shape) + d_prim
|
||||
return d_val
|
||||
|
||||
def get_sphere_collision(
|
||||
@@ -455,6 +468,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx=None,
|
||||
return_loss=False,
|
||||
**kwargs,
|
||||
):
|
||||
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
|
||||
return super().get_sphere_collision(
|
||||
@@ -501,6 +515,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
# log_warn("Swept: Mesh + Primitive Collision Checking is experimental")
|
||||
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
|
||||
@@ -514,6 +529,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
|
||||
d = self._get_swept_sdf(
|
||||
@@ -540,6 +556,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
d_val = d.view(d_prim.shape) + d_prim
|
||||
|
||||
@@ -602,4 +619,11 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
self._wp_mesh_cache = {}
|
||||
if self._mesh_tensor_list is not None:
|
||||
self._mesh_tensor_list[2][:] = 0
|
||||
if self._env_n_mesh is not None:
|
||||
self._env_n_mesh[:] = 0
|
||||
if self._env_mesh_names is not None:
|
||||
self._env_mesh_names = [
|
||||
[None for _ in range(self.cache["mesh"])] for _ in range(self.n_envs)
|
||||
]
|
||||
|
||||
super().clear_cache()
|
||||
|
||||
699
src/curobo/geom/sdf/world_voxel.py
Normal file
699
src/curobo/geom/sdf/world_voxel.py
Normal file
@@ -0,0 +1,699 @@
|
||||
#
|
||||
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
#
|
||||
# Standard Library
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Third Party
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.curobolib.geom import SdfSphereVoxel, SdfSweptSphereVoxel
|
||||
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollisionConfig
|
||||
from curobo.geom.sdf.world_mesh import WorldMeshCollision
|
||||
from curobo.geom.types import VoxelGrid, WorldConfig
|
||||
from curobo.types.math import Pose
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
|
||||
|
||||
class WorldVoxelCollision(WorldMeshCollision):
|
||||
"""Voxel grid representation of World, with each voxel containing Euclidean Signed Distance."""
|
||||
|
||||
def __init__(self, config: WorldCollisionConfig):
|
||||
self._env_n_voxels = None
|
||||
self._voxel_tensor_list = None
|
||||
self._env_voxel_names = None
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
def _init_cache(self):
|
||||
if (
|
||||
self.cache is not None
|
||||
and "voxel" in self.cache
|
||||
and self.cache["voxel"] not in [None, 0]
|
||||
):
|
||||
self._create_voxel_cache(self.cache["voxel"])
|
||||
|
||||
def _create_voxel_cache(self, voxel_cache: Dict[str, Any]):
|
||||
n_layers = voxel_cache["layers"]
|
||||
dims = voxel_cache["dims"]
|
||||
voxel_size = voxel_cache["voxel_size"]
|
||||
feature_dtype = voxel_cache["feature_dtype"]
|
||||
n_voxels = int(
|
||||
math.floor(dims[0] / voxel_size)
|
||||
* math.floor(dims[1] / voxel_size)
|
||||
* math.floor(dims[2] / voxel_size)
|
||||
)
|
||||
|
||||
voxel_params = torch.zeros(
|
||||
(self.n_envs, n_layers, 4),
|
||||
dtype=self.tensor_args.dtype,
|
||||
device=self.tensor_args.device,
|
||||
)
|
||||
voxel_pose = torch.zeros(
|
||||
(self.n_envs, n_layers, 8),
|
||||
dtype=self.tensor_args.dtype,
|
||||
device=self.tensor_args.device,
|
||||
)
|
||||
voxel_pose[..., 3] = 1.0
|
||||
voxel_enable = torch.zeros(
|
||||
(self.n_envs, n_layers), dtype=torch.uint8, device=self.tensor_args.device
|
||||
)
|
||||
self._env_n_voxels = torch.zeros(
|
||||
(self.n_envs), device=self.tensor_args.device, dtype=torch.int32
|
||||
)
|
||||
voxel_features = torch.zeros(
|
||||
(self.n_envs, n_layers, n_voxels, 1),
|
||||
device=self.tensor_args.device,
|
||||
dtype=feature_dtype,
|
||||
)
|
||||
|
||||
self._voxel_tensor_list = [voxel_params, voxel_pose, voxel_enable, voxel_features]
|
||||
self.collision_types["voxel"] = True
|
||||
self._env_voxel_names = [[None for _ in range(n_layers)] for _ in range(self.n_envs)]
|
||||
|
||||
def load_collision_model(
|
||||
self, world_model: WorldConfig, env_idx=0, fix_cache_reference: bool = False
|
||||
):
|
||||
self._load_collision_model_in_cache(
|
||||
world_model, env_idx, fix_cache_reference=fix_cache_reference
|
||||
)
|
||||
return super().load_collision_model(
|
||||
world_model, env_idx=env_idx, fix_cache_reference=fix_cache_reference
|
||||
)
|
||||
|
||||
def _load_collision_model_in_cache(
|
||||
self, world_config: WorldConfig, env_idx: int = 0, fix_cache_reference: bool = False
|
||||
):
|
||||
"""TODO:
|
||||
|
||||
_extended_summary_
|
||||
|
||||
Args:
|
||||
world_config: _description_
|
||||
env_idx: _description_
|
||||
fix_cache_reference: _description_
|
||||
"""
|
||||
voxel_objs = world_config.voxel
|
||||
max_obs = len(voxel_objs)
|
||||
self.world_model = world_config
|
||||
if max_obs < 1:
|
||||
log_info("No Voxel objs")
|
||||
return
|
||||
if self._voxel_tensor_list is None or self._voxel_tensor_list[0].shape[1] < max_obs:
|
||||
if not fix_cache_reference:
|
||||
log_info("Creating Voxel cache" + str(max_obs))
|
||||
self._create_voxel_cache(
|
||||
{
|
||||
"layers": max_obs,
|
||||
"dims": voxel_objs[0].dims,
|
||||
"voxel_size": voxel_objs[0].voxel_size,
|
||||
"feature_dtype": voxel_objs[0].feature_dtype,
|
||||
}
|
||||
)
|
||||
else:
|
||||
log_error("number of OBB is larger than collision cache, create larger cache.")
|
||||
|
||||
# load as a batch:
|
||||
pose_batch = [c.pose for c in voxel_objs]
|
||||
dims_batch = [c.dims for c in voxel_objs]
|
||||
names_batch = [c.name for c in voxel_objs]
|
||||
size_batch = [c.voxel_size for c in voxel_objs]
|
||||
voxel_batch = self._batch_tensor_voxel(pose_batch, dims_batch, size_batch)
|
||||
self._voxel_tensor_list[0][env_idx, :max_obs, :] = voxel_batch[0]
|
||||
self._voxel_tensor_list[1][env_idx, :max_obs, :7] = voxel_batch[1]
|
||||
|
||||
self._voxel_tensor_list[2][env_idx, :max_obs] = 1 # enabling obstacle
|
||||
|
||||
self._voxel_tensor_list[2][env_idx, max_obs:] = 0 # disabling obstacle
|
||||
|
||||
# copy voxel grid features:
|
||||
|
||||
self._env_n_voxels[env_idx] = max_obs
|
||||
self._env_voxel_names[env_idx][:max_obs] = names_batch
|
||||
self.collision_types["voxel"] = True
|
||||
|
||||
def _batch_tensor_voxel(
|
||||
self, pose: List[List[float]], dims: List[float], voxel_size: List[float]
|
||||
):
|
||||
w_T_b = Pose.from_batch_list(pose, tensor_args=self.tensor_args)
|
||||
b_T_w = w_T_b.inverse()
|
||||
dims_t = torch.as_tensor(
|
||||
np.array(dims), device=self.tensor_args.device, dtype=self.tensor_args.dtype
|
||||
)
|
||||
size_t = torch.as_tensor(
|
||||
np.array(voxel_size), device=self.tensor_args.device, dtype=self.tensor_args.dtype
|
||||
).unsqueeze(-1)
|
||||
params_t = torch.cat([dims_t, size_t], dim=-1)
|
||||
|
||||
voxel_list = [params_t, b_T_w.get_pose_vector()]
|
||||
return voxel_list
|
||||
|
||||
def load_batch_collision_model(self, world_config_list: List[WorldConfig]):
|
||||
"""Load voxel grid for batched environments
|
||||
|
||||
_extended_summary_
|
||||
|
||||
Args:
|
||||
world_config_list: _description_
|
||||
|
||||
Returns:
|
||||
_description_
|
||||
"""
|
||||
log_error("Not Implemented")
|
||||
# First find largest number of cuboid:
|
||||
c_len = []
|
||||
pose_batch = []
|
||||
dims_batch = []
|
||||
names_batch = []
|
||||
vsize_batch = []
|
||||
for i in world_config_list:
|
||||
c = i.cuboid
|
||||
if c is not None:
|
||||
c_len.append(len(c))
|
||||
pose_batch.extend([i.pose for i in c])
|
||||
dims_batch.extend([i.dims for i in c])
|
||||
names_batch.extend([i.name for i in c])
|
||||
vsize_batch.extend([i.voxel_size for i in c])
|
||||
else:
|
||||
c_len.append(0)
|
||||
|
||||
max_obs = max(c_len)
|
||||
if max_obs < 1:
|
||||
log_warn("No obbs found")
|
||||
return
|
||||
# check if number of environments is same as config:
|
||||
reset_buffers = False
|
||||
if self._env_n_voxels is not None and len(world_config_list) != len(self._env_n_voxels):
|
||||
log_warn(
|
||||
"env_n_voxels is not same as world_config_list, reloading collision buffers (breaks CG)"
|
||||
)
|
||||
reset_buffers = True
|
||||
self.n_envs = len(world_config_list)
|
||||
self._env_n_voxels = torch.zeros(
|
||||
(self.n_envs), device=self.tensor_args.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
if self._voxel_tensor_list is not None and self._voxel_tensor_list[0].shape[1] < max_obs:
|
||||
log_warn(
|
||||
"number of obbs is greater than buffer, reloading collision buffers (breaks CG)"
|
||||
)
|
||||
reset_buffers = True
|
||||
# create cache if does not exist:
|
||||
if self._voxel_tensor_list is None or reset_buffers:
|
||||
log_info("Creating Obb cache" + str(max_obs))
|
||||
self._create_obb_cache(max_obs)
|
||||
|
||||
# load obstacles:
|
||||
## load data into gpu:
|
||||
voxel_batch = self._batch_tensor_voxel(pose_batch, dims_batch, vsize_batch)
|
||||
c_start = 0
|
||||
for i in range(len(self._env_n_voxels)):
|
||||
if c_len[i] > 0:
|
||||
# load obb:
|
||||
self._voxel_tensor_list[0][i, : c_len[i], :] = voxel_batch[0][
|
||||
c_start : c_start + c_len[i]
|
||||
]
|
||||
self._voxel_tensor_list[1][i, : c_len[i], :7] = voxel_batch[1][
|
||||
c_start : c_start + c_len[i]
|
||||
]
|
||||
self._voxel_tensor_list[2][i, : c_len[i]] = 1
|
||||
self._env_voxel_names[i][: c_len[i]] = names_batch[c_start : c_start + c_len[i]]
|
||||
self._voxel_tensor_list[2][i, c_len[i] :] = 0
|
||||
c_start += c_len[i]
|
||||
self._env_n_voxels[:] = torch.as_tensor(
|
||||
c_len, dtype=torch.int32, device=self.tensor_args.device
|
||||
)
|
||||
self.collision_types["voxel"] = True
|
||||
|
||||
return super().load_batch_collision_model(world_config_list)
|
||||
|
||||
def enable_obstacle(
|
||||
self,
|
||||
name: str,
|
||||
enable: bool = True,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
if self._env_voxel_names is not None and name in self._env_voxel_names[env_idx]:
|
||||
self.enable_voxel(enable, name, None, env_idx)
|
||||
else:
|
||||
return super().enable_obstacle(name, enable, env_idx)
|
||||
|
||||
def enable_voxel(
|
||||
self,
|
||||
enable: bool = True,
|
||||
name: Optional[str] = None,
|
||||
env_obj_idx: Optional[torch.Tensor] = None,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
"""Update obstacle dimensions
|
||||
|
||||
Args:
|
||||
obj_dims (torch.Tensor): [dim.x,dim.y, dim.z], give as [b,3]
|
||||
obj_idx (torch.Tensor or int):
|
||||
|
||||
"""
|
||||
if env_obj_idx is not None:
|
||||
self._voxel_tensor_list[2][env_obj_idx] = int(enable) # enable == 1
|
||||
else:
|
||||
# find index of given name:
|
||||
obs_idx = self.get_voxel_idx(name, env_idx)
|
||||
|
||||
self._voxel_tensor_list[2][env_idx, obs_idx] = int(enable)
|
||||
|
||||
def update_obstacle_pose(
|
||||
self,
|
||||
name: str,
|
||||
w_obj_pose: Pose,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
if self._env_voxel_names is not None and name in self._env_voxel_names[env_idx]:
|
||||
self.update_voxel_pose(name=name, w_obj_pose=w_obj_pose, env_idx=env_idx)
|
||||
else:
|
||||
log_error("obstacle not found in OBB world model: " + name)
|
||||
|
||||
def update_voxel_data(self, new_voxel: VoxelGrid, env_idx: int = 0):
|
||||
obs_idx = self.get_voxel_idx(new_voxel.name, env_idx)
|
||||
self._voxel_tensor_list[3][env_idx, obs_idx, :, :] = new_voxel.feature_tensor.view(
|
||||
new_voxel.feature_tensor.shape[0], -1
|
||||
).to(dtype=self._voxel_tensor_list[3].dtype)
|
||||
self._voxel_tensor_list[0][env_idx, obs_idx, :3] = self.tensor_args.to_device(
|
||||
new_voxel.dims
|
||||
)
|
||||
self._voxel_tensor_list[0][env_idx, obs_idx, 3] = new_voxel.voxel_size
|
||||
self._voxel_tensor_list[1][env_idx, obs_idx, :7] = (
|
||||
Pose.from_list(new_voxel.pose, self.tensor_args).inverse().get_pose_vector()
|
||||
)
|
||||
self._voxel_tensor_list[2][env_idx, obs_idx] = int(True)
|
||||
|
||||
def update_voxel_features(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
env_obj_idx: Optional[torch.Tensor] = None,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
"""Update pose of a specific objects.
|
||||
This also updates the signed distance grid to account for the updated object pose.
|
||||
Args:
|
||||
obj_w_pose: Pose
|
||||
obj_idx:
|
||||
"""
|
||||
if env_obj_idx is not None:
|
||||
self._voxel_tensor_list[3][env_obj_idx, :] = features.to(
|
||||
dtype=self._voxel_tensor_list[3].dtype
|
||||
)
|
||||
else:
|
||||
obs_idx = self.get_voxel_idx(name, env_idx)
|
||||
self._voxel_tensor_list[3][env_idx, obs_idx, :] = features.to(
|
||||
dtype=self._voxel_tensor_list[3].dtype
|
||||
)
|
||||
|
||||
def update_voxel_pose(
|
||||
self,
|
||||
w_obj_pose: Optional[Pose] = None,
|
||||
obj_w_pose: Optional[Pose] = None,
|
||||
name: Optional[str] = None,
|
||||
env_obj_idx: Optional[torch.Tensor] = None,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
"""Update pose of a specific objects.
|
||||
This also updates the signed distance grid to account for the updated object pose.
|
||||
Args:
|
||||
obj_w_pose: Pose
|
||||
obj_idx:
|
||||
"""
|
||||
obj_w_pose = self._get_obstacle_poses(w_obj_pose, obj_w_pose)
|
||||
if env_obj_idx is not None:
|
||||
self._voxel_tensor_list[1][env_obj_idx, :7] = obj_w_pose.get_pose_vector()
|
||||
else:
|
||||
obs_idx = self.get_voxel_idx(name, env_idx)
|
||||
self._voxel_tensor_list[1][env_idx, obs_idx, :7] = obj_w_pose.get_pose_vector()
|
||||
|
||||
def get_voxel_idx(
|
||||
self,
|
||||
name: str,
|
||||
env_idx: int = 0,
|
||||
) -> int:
|
||||
if name not in self._env_voxel_names[env_idx]:
|
||||
log_error("Obstacle with name: " + name + " not found in current world", exc_info=True)
|
||||
return self._env_voxel_names[env_idx].index(name)
|
||||
|
||||
def get_voxel_grid(
|
||||
self,
|
||||
name: str,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
obs_idx = self.get_voxel_idx(name, env_idx)
|
||||
voxel_params = np.round(
|
||||
self._voxel_tensor_list[0][env_idx, obs_idx, :].cpu().numpy().astype(np.float64), 6
|
||||
).tolist()
|
||||
voxel_pose = Pose(
|
||||
position=self._voxel_tensor_list[1][env_idx, obs_idx, :3],
|
||||
quaternion=self._voxel_tensor_list[1][env_idx, obs_idx, 3:7],
|
||||
)
|
||||
voxel_features = self._voxel_tensor_list[3][env_idx, obs_idx, :]
|
||||
voxel_grid = VoxelGrid(
|
||||
name=name,
|
||||
dims=voxel_params[:3],
|
||||
pose=voxel_pose.to_list(),
|
||||
voxel_size=voxel_params[3],
|
||||
feature_tensor=voxel_features,
|
||||
)
|
||||
return voxel_grid
|
||||
|
||||
def get_sphere_distance(
|
||||
self,
|
||||
query_sphere,
|
||||
collision_query_buffer: CollisionQueryBuffer,
|
||||
weight: torch.Tensor,
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
|
||||
return super().get_sphere_distance(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
|
||||
b, h, n, _ = query_sphere.shape # This can be read from collision query buffer
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
dist = SdfSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
query_sphere.requires_grad,
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
sum_collisions,
|
||||
compute_esdf,
|
||||
)
|
||||
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_sphere_distance(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
if compute_esdf:
|
||||
d_val = torch.maximum(dist.view(d_prim.shape), d_prim)
|
||||
else:
|
||||
d_val = d_val.view(d_prim.shape) + d_prim
|
||||
|
||||
return d_val
|
||||
|
||||
def get_sphere_collision(
|
||||
self,
|
||||
query_sphere,
|
||||
collision_query_buffer: CollisionQueryBuffer,
|
||||
weight: torch.Tensor,
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
**kwargs,
|
||||
):
|
||||
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
|
||||
return super().get_sphere_collision(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
)
|
||||
|
||||
if return_loss:
|
||||
raise ValueError("cannot return loss for classification, use get_sphere_distance")
|
||||
b, h, n, _ = query_sphere.shape
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
dist = SdfSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
query_sphere.requires_grad,
|
||||
False,
|
||||
use_batch_env,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_sphere_collision(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight,
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
)
|
||||
d_val = dist.view(d_prim.shape) + d_prim
|
||||
return d_val
|
||||
|
||||
def get_swept_sphere_distance(
|
||||
self,
|
||||
query_sphere,
|
||||
collision_query_buffer: CollisionQueryBuffer,
|
||||
weight: torch.Tensor,
|
||||
activation_distance: torch.Tensor,
|
||||
speed_dt: torch.Tensor,
|
||||
sweep_steps: int,
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
"""
|
||||
Computes the signed distance via analytic function
|
||||
Args:
|
||||
tensor_sphere: b, n, 4
|
||||
"""
|
||||
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
|
||||
return super().get_swept_sphere_distance(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
env_query_idx=env_query_idx,
|
||||
sweep_steps=sweep_steps,
|
||||
activation_distance=activation_distance,
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
b, h, n, _ = query_sphere.shape
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
|
||||
dist = SdfSweptSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
speed_dt,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
query_sphere.requires_grad,
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
sum_collisions,
|
||||
)
|
||||
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_swept_sphere_distance(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
env_query_idx=env_query_idx,
|
||||
sweep_steps=sweep_steps,
|
||||
activation_distance=activation_distance,
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
|
||||
d_val = dist.view(d_prim.shape) + d_prim
|
||||
return d_val
|
||||
|
||||
def get_swept_sphere_collision(
|
||||
self,
|
||||
query_sphere,
|
||||
collision_query_buffer: CollisionQueryBuffer,
|
||||
weight: torch.Tensor,
|
||||
activation_distance: torch.Tensor,
|
||||
speed_dt: torch.Tensor,
|
||||
sweep_steps: int,
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
):
|
||||
"""
|
||||
Computes the signed distance via analytic function
|
||||
Args:
|
||||
tensor_sphere: b, n, 4
|
||||
"""
|
||||
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
|
||||
return super().get_swept_sphere_collision(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
env_query_idx=env_query_idx,
|
||||
sweep_steps=sweep_steps,
|
||||
activation_distance=activation_distance,
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
)
|
||||
if return_loss:
|
||||
raise ValueError("cannot return loss for classify, use get_swept_sphere_distance")
|
||||
b, h, n, _ = query_sphere.shape
|
||||
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
dist = SdfSweptSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
speed_dt,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
query_sphere.requires_grad,
|
||||
False,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
True,
|
||||
)
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_swept_sphere_collision(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
env_query_idx=env_query_idx,
|
||||
sweep_steps=sweep_steps,
|
||||
activation_distance=activation_distance,
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
)
|
||||
d_val = dist.view(d_prim.shape) + d_prim
|
||||
return d_val
|
||||
|
||||
def clear_cache(self):
|
||||
if self._voxel_tensor_list is not None:
|
||||
self._voxel_tensor_list[2][:] = 0
|
||||
self._voxel_tensor_list[-1][:] = -1.0 * self.max_distance
|
||||
self._env_n_voxels[:] = 0
|
||||
@@ -18,6 +18,7 @@ import warp as wp
|
||||
# CuRobo
|
||||
from curobo.curobolib.kinematics import rotation_matrix_to_quaternion
|
||||
from curobo.util.logger import log_error
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.warp import init_warp
|
||||
|
||||
|
||||
@@ -27,11 +28,11 @@ def transform_points(
|
||||
if out_points is None:
|
||||
out_points = torch.zeros((points.shape[0], 3), device=points.device, dtype=points.dtype)
|
||||
if out_gp is None:
|
||||
out_gp = torch.zeros((position.shape[0], 3), device=position.device)
|
||||
out_gp = torch.zeros((position.shape[0], 3), device=position.device, dtype=points.dtype)
|
||||
if out_gq is None:
|
||||
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device)
|
||||
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device, dtype=points.dtype)
|
||||
if out_gpt is None:
|
||||
out_gpt = torch.zeros((points.shape[0], 3), device=position.device)
|
||||
out_gpt = torch.zeros((points.shape[0], 3), device=position.device, dtype=points.dtype)
|
||||
out_points = TransformPoint.apply(
|
||||
position, quaternion, points, out_points, out_gp, out_gq, out_gpt
|
||||
)
|
||||
@@ -46,18 +47,20 @@ def batch_transform_points(
|
||||
(points.shape[0], points.shape[1], 3), device=points.device, dtype=points.dtype
|
||||
)
|
||||
if out_gp is None:
|
||||
out_gp = torch.zeros((position.shape[0], 3), device=position.device)
|
||||
out_gp = torch.zeros((position.shape[0], 3), device=position.device, dtype=points.dtype)
|
||||
if out_gq is None:
|
||||
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device)
|
||||
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device, dtype=points.dtype)
|
||||
if out_gpt is None:
|
||||
out_gpt = torch.zeros((points.shape[0], points.shape[1], 3), device=position.device)
|
||||
out_gpt = torch.zeros(
|
||||
(points.shape[0], points.shape[1], 3), device=position.device, dtype=points.dtype
|
||||
)
|
||||
out_points = BatchTransformPoint.apply(
|
||||
position, quaternion, points, out_points, out_gp, out_gq, out_gpt
|
||||
)
|
||||
return out_points
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def get_inv_transform(w_rot_c, w_trans_c):
|
||||
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
|
||||
c_rot_w = w_rot_c.transpose(-1, -2)
|
||||
@@ -65,7 +68,7 @@ def get_inv_transform(w_rot_c, w_trans_c):
|
||||
return c_rot_w, c_trans_w
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def transform_point_inverse(point, rot, trans):
|
||||
# type: (Tensor, Tensor, Tensor) -> Tensor
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Standard Library
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
@@ -564,6 +565,68 @@ class PointCloud(Obstacle):
|
||||
return new_spheres
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoxelGrid(Obstacle):
|
||||
dims: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0])
|
||||
voxel_size: float = 0.02 # meters
|
||||
feature_tensor: Optional[torch.Tensor] = None
|
||||
xyzr_tensor: Optional[torch.Tensor] = None
|
||||
feature_dtype: torch.dtype = torch.float32
|
||||
|
||||
def __post_init__(self):
|
||||
if self.feature_tensor is not None:
|
||||
self.feature_dtype = self.feature_tensor.dtype
|
||||
|
||||
def create_xyzr_tensor(
|
||||
self, transform_to_origin: bool = False, tensor_args: TensorDeviceType = TensorDeviceType()
|
||||
):
|
||||
bounds = self.dims
|
||||
low = [-bounds[0] / 2, -bounds[1] / 2, -bounds[2] / 2]
|
||||
high = [bounds[0] / 2, bounds[1] / 2, bounds[2] / 2]
|
||||
trange = [h - l for l, h in zip(low, high)]
|
||||
x = torch.linspace(
|
||||
low[0], high[0], int(math.floor(trange[0] / self.voxel_size)), device=tensor_args.device
|
||||
)
|
||||
y = torch.linspace(
|
||||
low[1], high[1], int(math.floor(trange[1] / self.voxel_size)), device=tensor_args.device
|
||||
)
|
||||
z = torch.linspace(
|
||||
low[2], high[2], int(math.floor(trange[2] / self.voxel_size)), device=tensor_args.device
|
||||
)
|
||||
w, l, h = x.shape[0], y.shape[0], z.shape[0]
|
||||
xyz = (
|
||||
torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
|
||||
)
|
||||
|
||||
if transform_to_origin:
|
||||
pose = Pose.from_list(self.pose, tensor_args=tensor_args)
|
||||
xyz = pose.transform_points(xyz.contiguous())
|
||||
r = torch.zeros_like(xyz[:, 0:1]) + (self.voxel_size * 0.5)
|
||||
xyzr = torch.cat([xyz, r], dim=1)
|
||||
return xyzr
|
||||
|
||||
def get_occupied_voxels(self, feature_threshold: Optional[float] = None):
|
||||
if feature_threshold is None:
|
||||
feature_threshold = -1.0 * self.voxel_size
|
||||
if self.xyzr_tensor is None or self.feature_tensor is None:
|
||||
log_error("Feature tensor or xyzr tensor is empty")
|
||||
xyzr = self.xyzr_tensor.clone()
|
||||
xyzr[:, 3] = self.feature_tensor
|
||||
occupied = xyzr[self.feature_tensor > feature_threshold]
|
||||
return occupied
|
||||
|
||||
def clone(self):
|
||||
return VoxelGrid(
|
||||
name=self.name,
|
||||
pose=self.pose.copy(),
|
||||
dims=self.dims.copy(),
|
||||
feature_tensor=self.feature_tensor.clone() if self.feature_tensor is not None else None,
|
||||
xyzr_tensor=self.xyzr_tensor.clone() if self.xyzr_tensor is not None else None,
|
||||
feature_dtype=self.feature_dtype,
|
||||
voxel_size=self.voxel_size,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorldConfig(Sequence):
|
||||
"""Representation of World for use in CuRobo."""
|
||||
@@ -586,25 +649,13 @@ class WorldConfig(Sequence):
|
||||
#: BloxMap obstacle.
|
||||
blox: Optional[List[BloxMap]] = None
|
||||
|
||||
voxel: Optional[List[VoxelGrid]] = None
|
||||
|
||||
#: List of all obstacles in world.
|
||||
objects: Optional[List[Obstacle]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# create objects list:
|
||||
if self.objects is None:
|
||||
self.objects = []
|
||||
if self.sphere is not None:
|
||||
self.objects += self.sphere
|
||||
if self.cuboid is not None:
|
||||
self.objects += self.cuboid
|
||||
if self.capsule is not None:
|
||||
self.objects += self.capsule
|
||||
if self.mesh is not None:
|
||||
self.objects += self.mesh
|
||||
if self.blox is not None:
|
||||
self.objects += self.blox
|
||||
if self.cylinder is not None:
|
||||
self.objects += self.cylinder
|
||||
if self.sphere is None:
|
||||
self.sphere = []
|
||||
if self.cuboid is None:
|
||||
@@ -617,6 +668,18 @@ class WorldConfig(Sequence):
|
||||
self.cylinder = []
|
||||
if self.blox is None:
|
||||
self.blox = []
|
||||
if self.voxel is None:
|
||||
self.voxel = []
|
||||
if self.objects is None:
|
||||
self.objects = (
|
||||
self.sphere
|
||||
+ self.cuboid
|
||||
+ self.capsule
|
||||
+ self.mesh
|
||||
+ self.cylinder
|
||||
+ self.blox
|
||||
+ self.voxel
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.objects)
|
||||
@@ -632,6 +695,7 @@ class WorldConfig(Sequence):
|
||||
capsule=self.capsule.copy() if self.capsule is not None else None,
|
||||
cylinder=self.cylinder.copy() if self.cylinder is not None else None,
|
||||
blox=self.blox.copy() if self.blox is not None else None,
|
||||
voxel=self.voxel.copy() if self.voxel is not None else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -642,6 +706,7 @@ class WorldConfig(Sequence):
|
||||
mesh = None
|
||||
blox = None
|
||||
cylinder = None
|
||||
voxel = None
|
||||
# load yaml:
|
||||
if "cuboid" in data_dict.keys():
|
||||
cuboid = [Cuboid(name=x, **data_dict["cuboid"][x]) for x in data_dict["cuboid"]]
|
||||
@@ -655,6 +720,8 @@ class WorldConfig(Sequence):
|
||||
cylinder = [Cylinder(name=x, **data_dict["cylinder"][x]) for x in data_dict["cylinder"]]
|
||||
if "blox" in data_dict.keys():
|
||||
blox = [BloxMap(name=x, **data_dict["blox"][x]) for x in data_dict["blox"]]
|
||||
if "voxel" in data_dict.keys():
|
||||
voxel = [VoxelGrid(name=x, **data_dict["voxel"][x]) for x in data_dict["voxel"]]
|
||||
|
||||
return WorldConfig(
|
||||
cuboid=cuboid,
|
||||
@@ -663,6 +730,7 @@ class WorldConfig(Sequence):
|
||||
cylinder=cylinder,
|
||||
mesh=mesh,
|
||||
blox=blox,
|
||||
voxel=voxel,
|
||||
)
|
||||
|
||||
# load world config as obbs: convert all types to obbs
|
||||
@@ -688,6 +756,10 @@ class WorldConfig(Sequence):
|
||||
|
||||
if current_world.mesh is not None and len(current_world.mesh) > 0:
|
||||
mesh_obb = [x.get_cuboid() for x in current_world.mesh]
|
||||
|
||||
if current_world.voxel is not None and len(current_world.voxel) > 0:
|
||||
log_error("VoxelGrid cannot be converted to obb world")
|
||||
|
||||
return WorldConfig(
|
||||
cuboid=cuboid_obb + sphere_obb + capsule_obb + cylinder_obb + mesh_obb + blox_obb
|
||||
)
|
||||
@@ -714,6 +786,8 @@ class WorldConfig(Sequence):
|
||||
for i in range(len(current_world.blox)):
|
||||
if current_world.blox[i].mesh is not None:
|
||||
blox_obb.append(current_world.blox[i].get_mesh(process=process))
|
||||
if current_world.voxel is not None and len(current_world.voxel) > 0:
|
||||
log_error("VoxelGrid cannot be converted to mesh world")
|
||||
|
||||
return WorldConfig(
|
||||
mesh=current_world.mesh
|
||||
@@ -750,6 +824,7 @@ class WorldConfig(Sequence):
|
||||
return WorldConfig(
|
||||
mesh=current_world.mesh + sphere_obb + capsule_obb + cylinder_obb + blox_obb,
|
||||
cuboid=cuboid_obb,
|
||||
voxel=current_world.voxel,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -822,6 +897,8 @@ class WorldConfig(Sequence):
|
||||
self.cylinder.append(obstacle)
|
||||
elif isinstance(obstacle, Capsule):
|
||||
self.capsule.append(obstacle)
|
||||
elif isinstance(obstacle, VoxelGrid):
|
||||
self.voxel.append(obstacle)
|
||||
else:
|
||||
ValueError("Obstacle type not supported")
|
||||
self.objects.append(obstacle)
|
||||
|
||||
Reference in New Issue
Block a user