kinematics refactor with mimic joint support
This commit is contained in:
@@ -263,13 +263,16 @@ class WorldCollisionConfig:
|
||||
cache: Optional[Dict[Obstacle, int]] = None
|
||||
n_envs: int = 1
|
||||
checker_type: CollisionCheckerType = CollisionCheckerType.PRIMITIVE
|
||||
max_distance: Union[torch.Tensor, float] = 0.01
|
||||
max_distance: Union[torch.Tensor, float] = 0.1
|
||||
max_esdf_distance: Union[torch.Tensor, float] = 1000.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.world_model is not None and isinstance(self.world_model, list):
|
||||
self.n_envs = len(self.world_model)
|
||||
if isinstance(self.max_distance, float):
|
||||
self.max_distance = self.tensor_args.to_device([self.max_distance])
|
||||
if isinstance(self.max_esdf_distance, float):
|
||||
self.max_esdf_distance = self.tensor_args.to_device([self.max_esdf_distance])
|
||||
|
||||
@staticmethod
|
||||
def load_from_dict(
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch.autograd.profiler as profiler
|
||||
|
||||
# CuRobo
|
||||
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollisionConfig
|
||||
from curobo.geom.sdf.world_mesh import WorldMeshCollision
|
||||
from curobo.geom.sdf.world_voxel import WorldVoxelCollision
|
||||
from curobo.geom.types import Cuboid, Mesh, Sphere, SphereFitType, WorldConfig
|
||||
from curobo.types.camera import CameraObservation
|
||||
from curobo.types.math import Pose
|
||||
@@ -33,7 +33,7 @@ except ImportError:
|
||||
from abc import ABC as Mapper
|
||||
|
||||
|
||||
class WorldBloxCollision(WorldMeshCollision):
|
||||
class WorldBloxCollision(WorldVoxelCollision):
|
||||
"""World Collision Representaiton using Nvidia's nvblox library.
|
||||
|
||||
This class depends on pytorch wrapper for nvblox.
|
||||
@@ -127,6 +127,7 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
collision_query_buffer: CollisionQueryBuffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
d = self._blox_mapper.query_sphere_sdf_cost(
|
||||
query_spheres,
|
||||
@@ -196,6 +197,7 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
activation_distance=activation_distance,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
|
||||
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
|
||||
@@ -227,6 +229,7 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx=None,
|
||||
return_loss: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if "blox" not in self.collision_types or not self.collision_types["blox"]:
|
||||
return super().get_sphere_collision(
|
||||
|
||||
@@ -48,11 +48,10 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
dims = voxel_cache["dims"]
|
||||
voxel_size = voxel_cache["voxel_size"]
|
||||
feature_dtype = voxel_cache["feature_dtype"]
|
||||
n_voxels = int(
|
||||
math.floor(dims[0] / voxel_size)
|
||||
* math.floor(dims[1] / voxel_size)
|
||||
* math.floor(dims[2] / voxel_size)
|
||||
)
|
||||
grid_shape = VoxelGrid(
|
||||
"test", pose=[0, 0, 0, 1, 0, 0, 0], dims=dims, voxel_size=voxel_size
|
||||
).get_grid_shape()[0]
|
||||
n_voxels = grid_shape[0] * grid_shape[1] * grid_shape[2]
|
||||
|
||||
voxel_params = torch.zeros(
|
||||
(self.n_envs, n_layers, 4),
|
||||
@@ -77,6 +76,12 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
dtype=feature_dtype,
|
||||
)
|
||||
|
||||
if feature_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
voxel_features[:] = -1.0 * self.max_esdf_distance
|
||||
else:
|
||||
voxel_features = (voxel_features.to(dtype=torch.float16) - self.max_esdf_distance).to(
|
||||
dtype=feature_dtype
|
||||
)
|
||||
self._voxel_tensor_list = [voxel_params, voxel_pose, voxel_enable, voxel_features]
|
||||
self.collision_types["voxel"] = True
|
||||
self._env_voxel_names = [[None for _ in range(n_layers)] for _ in range(self.n_envs)]
|
||||
@@ -84,14 +89,14 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
def load_collision_model(
|
||||
self, world_model: WorldConfig, env_idx=0, fix_cache_reference: bool = False
|
||||
):
|
||||
self._load_collision_model_in_cache(
|
||||
self._load_voxel_collision_model_in_cache(
|
||||
world_model, env_idx, fix_cache_reference=fix_cache_reference
|
||||
)
|
||||
return super().load_collision_model(
|
||||
world_model, env_idx=env_idx, fix_cache_reference=fix_cache_reference
|
||||
)
|
||||
|
||||
def _load_collision_model_in_cache(
|
||||
def _load_voxel_collision_model_in_cache(
|
||||
self, world_config: WorldConfig, env_idx: int = 0, fix_cache_reference: bool = False
|
||||
):
|
||||
"""TODO:
|
||||
@@ -396,9 +401,10 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
|
||||
b, h, n, _ = query_sphere.shape # This can be read from collision query buffer
|
||||
use_batch_env = True
|
||||
env_query_idx_voxel = env_query_idx
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
env_query_idx_voxel = self._env_n_voxels
|
||||
dist = SdfSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
@@ -406,13 +412,13 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self.max_esdf_distance,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
env_query_idx_voxel,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
@@ -424,12 +430,8 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
sum_collisions,
|
||||
compute_esdf,
|
||||
)
|
||||
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
|
||||
"mesh" not in self.collision_types or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_sphere_distance(
|
||||
@@ -443,9 +445,10 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
if compute_esdf:
|
||||
|
||||
d_val = torch.maximum(dist.view(d_prim.shape), d_prim)
|
||||
else:
|
||||
d_val = d_val.view(d_prim.shape) + d_prim
|
||||
d_val = dist.view(d_prim.shape) + d_prim
|
||||
|
||||
return d_val
|
||||
|
||||
@@ -473,9 +476,10 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
raise ValueError("cannot return loss for classification, use get_sphere_distance")
|
||||
b, h, n, _ = query_sphere.shape
|
||||
use_batch_env = True
|
||||
env_query_idx_voxel = env_query_idx
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
env_query_idx_voxel = self._env_n_voxels
|
||||
dist = SdfSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
@@ -483,13 +487,13 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self.max_esdf_distance,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
env_query_idx_voxel,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
@@ -501,11 +505,8 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
True,
|
||||
False,
|
||||
)
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
|
||||
"mesh" not in self.collision_types or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_sphere_collision(
|
||||
@@ -552,9 +553,10 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
)
|
||||
b, h, n, _ = query_sphere.shape
|
||||
use_batch_env = True
|
||||
env_query_idx_voxel = env_query_idx
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
env_query_idx_voxel = self._env_n_voxels
|
||||
|
||||
dist = SdfSweptSphereVoxel.apply(
|
||||
query_sphere,
|
||||
@@ -563,14 +565,14 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self.max_esdf_distance,
|
||||
speed_dt,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
env_query_idx_voxel,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
@@ -583,12 +585,8 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
return_loss,
|
||||
sum_collisions,
|
||||
)
|
||||
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
|
||||
"mesh" not in self.collision_types or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_swept_sphere_distance(
|
||||
@@ -641,9 +639,10 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
b, h, n, _ = query_sphere.shape
|
||||
|
||||
use_batch_env = True
|
||||
env_query_idx_voxel = env_query_idx
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
env_query_idx_voxel = self._env_n_voxels
|
||||
dist = SdfSweptSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
@@ -651,14 +650,14 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self.max_esdf_distance,
|
||||
speed_dt,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
env_query_idx_voxel,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
@@ -671,11 +670,8 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
return_loss,
|
||||
True,
|
||||
)
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
if ("primitive" not in self.collision_types or not self.collision_types["primitive"]) and (
|
||||
"mesh" not in self.collision_types or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_swept_sphere_collision(
|
||||
@@ -695,5 +691,15 @@ class WorldVoxelCollision(WorldMeshCollision):
|
||||
def clear_cache(self):
|
||||
if self._voxel_tensor_list is not None:
|
||||
self._voxel_tensor_list[2][:] = 0
|
||||
self._voxel_tensor_list[-1][:] = -1.0 * self.max_distance
|
||||
if self._voxel_tensor_list[3].dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
self._voxel_tensor_list[3][:] = -1.0 * self.max_esdf_distance
|
||||
else:
|
||||
self._voxel_tensor_list[3][:] = (
|
||||
self._voxel_tensor_list[3].to(dtype=torch.float16) * 0.0
|
||||
- self.max_esdf_distance
|
||||
).to(dtype=self._voxel_tensor_list[3].dtype)
|
||||
self._env_n_voxels[:] = 0
|
||||
print(self._voxel_tensor_list)
|
||||
|
||||
def get_voxel_grid_shape(self, env_idx: int = 0, obs_idx: int = 0):
|
||||
return self._voxel_tensor_list[3][env_idx, obs_idx].shape
|
||||
|
||||
@@ -577,22 +577,24 @@ class VoxelGrid(Obstacle):
|
||||
if self.feature_tensor is not None:
|
||||
self.feature_dtype = self.feature_tensor.dtype
|
||||
|
||||
def create_xyzr_tensor(
|
||||
self, transform_to_origin: bool = False, tensor_args: TensorDeviceType = TensorDeviceType()
|
||||
):
|
||||
def get_grid_shape(self):
|
||||
bounds = self.dims
|
||||
low = [-bounds[0] / 2, -bounds[1] / 2, -bounds[2] / 2]
|
||||
high = [bounds[0] / 2, bounds[1] / 2, bounds[2] / 2]
|
||||
trange = [h - l for l, h in zip(low, high)]
|
||||
x = torch.linspace(
|
||||
low[0], high[0], int(math.floor(trange[0] / self.voxel_size)), device=tensor_args.device
|
||||
)
|
||||
y = torch.linspace(
|
||||
low[1], high[1], int(math.floor(trange[1] / self.voxel_size)), device=tensor_args.device
|
||||
)
|
||||
z = torch.linspace(
|
||||
low[2], high[2], int(math.floor(trange[2] / self.voxel_size)), device=tensor_args.device
|
||||
)
|
||||
grid_shape = [
|
||||
1 + int(high[i] / self.voxel_size) - (int(low[i] / self.voxel_size))
|
||||
for i in range(len(low))
|
||||
]
|
||||
return grid_shape, low, high
|
||||
|
||||
def create_xyzr_tensor(
|
||||
self, transform_to_origin: bool = False, tensor_args: TensorDeviceType = TensorDeviceType()
|
||||
):
|
||||
trange, low, high = self.get_grid_shape()
|
||||
|
||||
x = torch.linspace(low[0], high[0], trange[0], device=tensor_args.device)
|
||||
y = torch.linspace(low[1], high[1], trange[1], device=tensor_args.device)
|
||||
z = torch.linspace(low[2], high[2], trange[2], device=tensor_args.device)
|
||||
w, l, h = x.shape[0], y.shape[0], z.shape[0]
|
||||
xyz = (
|
||||
torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
|
||||
@@ -603,11 +605,12 @@ class VoxelGrid(Obstacle):
|
||||
xyz = pose.transform_points(xyz.contiguous())
|
||||
r = torch.zeros_like(xyz[:, 0:1]) + (self.voxel_size * 0.5)
|
||||
xyzr = torch.cat([xyz, r], dim=1)
|
||||
|
||||
return xyzr
|
||||
|
||||
def get_occupied_voxels(self, feature_threshold: Optional[float] = None):
|
||||
if feature_threshold is None:
|
||||
feature_threshold = -1.0 * self.voxel_size
|
||||
feature_threshold = -0.5 * self.voxel_size
|
||||
if self.xyzr_tensor is None or self.feature_tensor is None:
|
||||
log_error("Feature tensor or xyzr tensor is empty")
|
||||
xyzr = self.xyzr_tensor.clone()
|
||||
|
||||
Reference in New Issue
Block a user