kinematics refactor with mimic joint support

This commit is contained in:
Balakumar Sundaralingam
2024-04-03 18:42:01 -07:00
parent b481ee201a
commit 774dcfd609
60 changed files with 2177 additions and 810 deletions

View File

@@ -263,13 +263,16 @@ class WorldCollisionConfig:
cache: Optional[Dict[Obstacle, int]] = None
n_envs: int = 1
checker_type: CollisionCheckerType = CollisionCheckerType.PRIMITIVE
max_distance: Union[torch.Tensor, float] = 0.01
max_distance: Union[torch.Tensor, float] = 0.1
max_esdf_distance: Union[torch.Tensor, float] = 1000.0
def __post_init__(self):
if self.world_model is not None and isinstance(self.world_model, list):
self.n_envs = len(self.world_model)
if isinstance(self.max_distance, float):
self.max_distance = self.tensor_args.to_device([self.max_distance])
if isinstance(self.max_esdf_distance, float):
self.max_esdf_distance = self.tensor_args.to_device([self.max_esdf_distance])
@staticmethod
def load_from_dict(

View File

@@ -18,7 +18,7 @@ import torch.autograd.profiler as profiler
# CuRobo
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollisionConfig
from curobo.geom.sdf.world_mesh import WorldMeshCollision
from curobo.geom.sdf.world_voxel import WorldVoxelCollision
from curobo.geom.types import Cuboid, Mesh, Sphere, SphereFitType, WorldConfig
from curobo.types.camera import CameraObservation
from curobo.types.math import Pose
@@ -33,7 +33,7 @@ except ImportError:
from abc import ABC as Mapper
class WorldBloxCollision(WorldMeshCollision):
class WorldBloxCollision(WorldVoxelCollision):
"""World Collision Representaiton using Nvidia's nvblox library.
This class depends on pytorch wrapper for nvblox.
@@ -127,6 +127,7 @@ class WorldBloxCollision(WorldMeshCollision):
collision_query_buffer: CollisionQueryBuffer,
weight,
activation_distance,
compute_esdf: bool = False,
):
d = self._blox_mapper.query_sphere_sdf_cost(
query_spheres,
@@ -196,6 +197,7 @@ class WorldBloxCollision(WorldMeshCollision):
collision_query_buffer,
weight=weight,
activation_distance=activation_distance,
compute_esdf=compute_esdf,
)
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
@@ -227,6 +229,7 @@ class WorldBloxCollision(WorldMeshCollision):
activation_distance: torch.Tensor,
env_query_idx=None,
return_loss: bool = False,
**kwargs,
):
if "blox" not in self.collision_types or not self.collision_types["blox"]:
return super().get_sphere_collision(

View File

@@ -48,11 +48,10 @@ class WorldVoxelCollision(WorldMeshCollision):
dims = voxel_cache["dims"]
voxel_size = voxel_cache["voxel_size"]
feature_dtype = voxel_cache["feature_dtype"]
n_voxels = int(
math.floor(dims[0] / voxel_size)
* math.floor(dims[1] / voxel_size)
* math.floor(dims[2] / voxel_size)
)
grid_shape = VoxelGrid(
"test", pose=[0, 0, 0, 1, 0, 0, 0], dims=dims, voxel_size=voxel_size
).get_grid_shape()[0]
n_voxels = grid_shape[0] * grid_shape[1] * grid_shape[2]
voxel_params = torch.zeros(
(self.n_envs, n_layers, 4),
@@ -77,6 +76,12 @@ class WorldVoxelCollision(WorldMeshCollision):
dtype=feature_dtype,
)
if feature_dtype in [torch.float32, torch.float16, torch.bfloat16]:
voxel_features[:] = -1.0 * self.max_esdf_distance
else:
voxel_features = (voxel_features.to(dtype=torch.float16) - self.max_esdf_distance).to(
dtype=feature_dtype
)
self._voxel_tensor_list = [voxel_params, voxel_pose, voxel_enable, voxel_features]
self.collision_types["voxel"] = True
self._env_voxel_names = [[None for _ in range(n_layers)] for _ in range(self.n_envs)]
@@ -84,14 +89,14 @@ class WorldVoxelCollision(WorldMeshCollision):
def load_collision_model(
self, world_model: WorldConfig, env_idx=0, fix_cache_reference: bool = False
):
self._load_collision_model_in_cache(
self._load_voxel_collision_model_in_cache(
world_model, env_idx, fix_cache_reference=fix_cache_reference
)
return super().load_collision_model(
world_model, env_idx=env_idx, fix_cache_reference=fix_cache_reference
)
def _load_collision_model_in_cache(
def _load_voxel_collision_model_in_cache(
self, world_config: WorldConfig, env_idx: int = 0, fix_cache_reference: bool = False
):
"""TODO:
@@ -396,9 +401,10 @@ class WorldVoxelCollision(WorldMeshCollision):
b, h, n, _ = query_sphere.shape # This can be read from collision query buffer
use_batch_env = True
env_query_idx_voxel = env_query_idx
if env_query_idx is None:
use_batch_env = False
env_query_idx = self._env_n_voxels
env_query_idx_voxel = self._env_n_voxels
dist = SdfSphereVoxel.apply(
query_sphere,
collision_query_buffer.voxel_collision_buffer.distance_buffer,
@@ -406,13 +412,13 @@ class WorldVoxelCollision(WorldMeshCollision):
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self.max_esdf_distance,
self._voxel_tensor_list[3],
self._voxel_tensor_list[0],
self._voxel_tensor_list[1],
self._voxel_tensor_list[2],
self._env_n_voxels,
env_query_idx,
env_query_idx_voxel,
self._voxel_tensor_list[0].shape[1],
b,
h,
@@ -424,12 +430,8 @@ class WorldVoxelCollision(WorldMeshCollision):
sum_collisions,
compute_esdf,
)
if (
"primitive" not in self.collision_types
or not self.collision_types["primitive"]
or "mesh" not in self.collision_types
or not self.collision_types["mesh"]
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
"mesh" not in self.collision_types or not self.collision_types["mesh"]
):
return dist
d_prim = super().get_sphere_distance(
@@ -443,9 +445,10 @@ class WorldVoxelCollision(WorldMeshCollision):
compute_esdf=compute_esdf,
)
if compute_esdf:
d_val = torch.maximum(dist.view(d_prim.shape), d_prim)
else:
d_val = d_val.view(d_prim.shape) + d_prim
d_val = dist.view(d_prim.shape) + d_prim
return d_val
@@ -473,9 +476,10 @@ class WorldVoxelCollision(WorldMeshCollision):
raise ValueError("cannot return loss for classification, use get_sphere_distance")
b, h, n, _ = query_sphere.shape
use_batch_env = True
env_query_idx_voxel = env_query_idx
if env_query_idx is None:
use_batch_env = False
env_query_idx = self._env_n_voxels
env_query_idx_voxel = self._env_n_voxels
dist = SdfSphereVoxel.apply(
query_sphere,
collision_query_buffer.voxel_collision_buffer.distance_buffer,
@@ -483,13 +487,13 @@ class WorldVoxelCollision(WorldMeshCollision):
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self.max_esdf_distance,
self._voxel_tensor_list[3],
self._voxel_tensor_list[0],
self._voxel_tensor_list[1],
self._voxel_tensor_list[2],
self._env_n_voxels,
env_query_idx,
env_query_idx_voxel,
self._voxel_tensor_list[0].shape[1],
b,
h,
@@ -501,11 +505,8 @@ class WorldVoxelCollision(WorldMeshCollision):
True,
False,
)
if (
"primitive" not in self.collision_types
or not self.collision_types["primitive"]
or "mesh" not in self.collision_types
or not self.collision_types["mesh"]
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
"mesh" not in self.collision_types or not self.collision_types["mesh"]
):
return dist
d_prim = super().get_sphere_collision(
@@ -552,9 +553,10 @@ class WorldVoxelCollision(WorldMeshCollision):
)
b, h, n, _ = query_sphere.shape
use_batch_env = True
env_query_idx_voxel = env_query_idx
if env_query_idx is None:
use_batch_env = False
env_query_idx = self._env_n_voxels
env_query_idx_voxel = self._env_n_voxels
dist = SdfSweptSphereVoxel.apply(
query_sphere,
@@ -563,14 +565,14 @@ class WorldVoxelCollision(WorldMeshCollision):
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self.max_esdf_distance,
speed_dt,
self._voxel_tensor_list[3],
self._voxel_tensor_list[0],
self._voxel_tensor_list[1],
self._voxel_tensor_list[2],
self._env_n_voxels,
env_query_idx,
env_query_idx_voxel,
self._voxel_tensor_list[0].shape[1],
b,
h,
@@ -583,12 +585,8 @@ class WorldVoxelCollision(WorldMeshCollision):
return_loss,
sum_collisions,
)
if (
"primitive" not in self.collision_types
or not self.collision_types["primitive"]
or "mesh" not in self.collision_types
or not self.collision_types["mesh"]
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
"mesh" not in self.collision_types or not self.collision_types["mesh"]
):
return dist
d_prim = super().get_swept_sphere_distance(
@@ -641,9 +639,10 @@ class WorldVoxelCollision(WorldMeshCollision):
b, h, n, _ = query_sphere.shape
use_batch_env = True
env_query_idx_voxel = env_query_idx
if env_query_idx is None:
use_batch_env = False
env_query_idx = self._env_n_voxels
env_query_idx_voxel = self._env_n_voxels
dist = SdfSweptSphereVoxel.apply(
query_sphere,
collision_query_buffer.voxel_collision_buffer.distance_buffer,
@@ -651,14 +650,14 @@ class WorldVoxelCollision(WorldMeshCollision):
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self.max_esdf_distance,
speed_dt,
self._voxel_tensor_list[3],
self._voxel_tensor_list[0],
self._voxel_tensor_list[1],
self._voxel_tensor_list[2],
self._env_n_voxels,
env_query_idx,
env_query_idx_voxel,
self._voxel_tensor_list[0].shape[1],
b,
h,
@@ -671,11 +670,8 @@ class WorldVoxelCollision(WorldMeshCollision):
return_loss,
True,
)
if (
"primitive" not in self.collision_types
or not self.collision_types["primitive"]
or "mesh" not in self.collision_types
or not self.collision_types["mesh"]
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
"mesh" not in self.collision_types or not self.collision_types["mesh"]
):
return dist
d_prim = super().get_swept_sphere_collision(
@@ -695,5 +691,15 @@ class WorldVoxelCollision(WorldMeshCollision):
def clear_cache(self):
if self._voxel_tensor_list is not None:
self._voxel_tensor_list[2][:] = 0
self._voxel_tensor_list[-1][:] = -1.0 * self.max_distance
if self._voxel_tensor_list[3].dtype in [torch.float32, torch.float16, torch.bfloat16]:
self._voxel_tensor_list[3][:] = -1.0 * self.max_esdf_distance
else:
self._voxel_tensor_list[3][:] = (
self._voxel_tensor_list[3].to(dtype=torch.float16) * 0.0
- self.max_esdf_distance
).to(dtype=self._voxel_tensor_list[3].dtype)
self._env_n_voxels[:] = 0
print(self._voxel_tensor_list)
def get_voxel_grid_shape(self, env_idx: int = 0, obs_idx: int = 0):
return self._voxel_tensor_list[3][env_idx, obs_idx].shape

View File

@@ -577,22 +577,24 @@ class VoxelGrid(Obstacle):
if self.feature_tensor is not None:
self.feature_dtype = self.feature_tensor.dtype
def create_xyzr_tensor(
self, transform_to_origin: bool = False, tensor_args: TensorDeviceType = TensorDeviceType()
):
def get_grid_shape(self):
bounds = self.dims
low = [-bounds[0] / 2, -bounds[1] / 2, -bounds[2] / 2]
high = [bounds[0] / 2, bounds[1] / 2, bounds[2] / 2]
trange = [h - l for l, h in zip(low, high)]
x = torch.linspace(
low[0], high[0], int(math.floor(trange[0] / self.voxel_size)), device=tensor_args.device
)
y = torch.linspace(
low[1], high[1], int(math.floor(trange[1] / self.voxel_size)), device=tensor_args.device
)
z = torch.linspace(
low[2], high[2], int(math.floor(trange[2] / self.voxel_size)), device=tensor_args.device
)
grid_shape = [
1 + int(high[i] / self.voxel_size) - (int(low[i] / self.voxel_size))
for i in range(len(low))
]
return grid_shape, low, high
def create_xyzr_tensor(
self, transform_to_origin: bool = False, tensor_args: TensorDeviceType = TensorDeviceType()
):
trange, low, high = self.get_grid_shape()
x = torch.linspace(low[0], high[0], trange[0], device=tensor_args.device)
y = torch.linspace(low[1], high[1], trange[1], device=tensor_args.device)
z = torch.linspace(low[2], high[2], trange[2], device=tensor_args.device)
w, l, h = x.shape[0], y.shape[0], z.shape[0]
xyz = (
torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
@@ -603,11 +605,12 @@ class VoxelGrid(Obstacle):
xyz = pose.transform_points(xyz.contiguous())
r = torch.zeros_like(xyz[:, 0:1]) + (self.voxel_size * 0.5)
xyzr = torch.cat([xyz, r], dim=1)
return xyzr
def get_occupied_voxels(self, feature_threshold: Optional[float] = None):
if feature_threshold is None:
feature_threshold = -1.0 * self.voxel_size
feature_threshold = -0.5 * self.voxel_size
if self.xyzr_tensor is None or self.feature_tensor is None:
log_error("Feature tensor or xyzr tensor is empty")
xyzr = self.xyzr_tensor.clone()