# Files
# gen_data_curobo/src/curobo/geom/sdf/world_mesh.py
# Balakumar Sundaralingam 07e6ccfc91 release repository
# 2023-10-26 04:17:19 -07:00
#
# 600 lines
# 22 KiB
# Python

#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from typing import List, Optional
# Third Party
import numpy as np
import torch
import warp as wp
# CuRobo
from curobo.geom.sdf.warp_primitives import SdfMeshWarpPy, SweptSdfMeshWarpPy
from curobo.geom.sdf.world import (
CollisionQueryBuffer,
WorldCollisionConfig,
WorldPrimitiveCollision,
)
from curobo.geom.types import Mesh, WorldConfig
from curobo.types.math import Pose
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.warp import init_warp
@dataclass(frozen=True, eq=True)
class WarpMeshData:
    """Immutable record describing one mesh that has been uploaded to warp.

    Fields:
        name: name of the source ``Mesh``; ``WorldMeshCollision`` uses it as the
            key of its warp mesh cache.
        m_id: ``id`` of the ``wp.Mesh`` below; this is the value written into the
            per-environment mesh-id tensor.
        vertices: warp array (``wp.vec3``) of vertices the mesh was built from.
        faces: flattened warp array of face vertex indices.
        mesh: the ``wp.Mesh`` object itself. The arrays and mesh are presumably
            held here so they stay alive while the id is in use.
    """

    name: str
    m_id: int
    vertices: wp.array
    faces: wp.array
    mesh: wp.Mesh
class WorldMeshCollision(WorldPrimitiveCollision):
    """World Mesh Collision using Nvidia's warp library

    This currently requires passing int64 array from torch to warp which is only
    available when compiled from source.

    Meshes are uploaded to warp once per unique ``Mesh.name`` (``_wp_mesh_cache``).
    Each environment references them through three tensors in ``_mesh_tensor_list``:

    * ``[0]`` int64 ``(n_envs, mesh_cache)``: warp mesh ids.
    * ``[1]`` ``(n_envs, mesh_cache, 8)`` in ``tensor_args.dtype``: inverse obstacle
      pose. Only the first 7 columns are written by this class.
    * ``[2]`` uint8 ``(n_envs, mesh_cache)``: 1 when the slot is enabled.

    ``_env_n_mesh`` stores the number of occupied slots per environment.
    ``_env_mesh_names`` stores the obstacle name held in each slot.
    """

    def __init__(self, config: WorldCollisionConfig):
        """Initialize warp and empty mesh bookkeeping, then run the parent constructor.

        Args:
            config: collision configuration. ``config.tensor_args`` selects the torch
                device, which is also mapped to the warp device.
        """
        init_warp()
        self.tensor_args = config.tensor_args
        self._env_n_mesh = None
        self._mesh_tensor_list = None
        self._env_mesh_names = None
        self._wp_device = wp.torch.device_from_torch(self.tensor_args.device)
        self._wp_mesh_cache = {}  # stores warp meshes across environments
        # The attributes above must exist before the parent constructor runs.
        # NOTE(review): presumably because it calls ``_init_cache`` — confirm.
        super().__init__(config)

    def _init_cache(self):
        """Create the mesh cache if ``self.cache["mesh"]`` requests a non-zero size."""
        if (
            self.cache is not None
            and "mesh" in self.cache
            and self.cache["mesh"] not in (None, 0)
        ):
            self._create_mesh_cache(self.cache["mesh"])
        return super()._init_cache()

    def _collision_type_enabled(self, kind: str) -> bool:
        """Return True when ``kind`` is present in ``collision_types`` and truthy."""
        return kind in self.collision_types and bool(self.collision_types[kind])

    def _truncate_env_meshes(self, env_idx: int, n_keep: int):
        """Keep the first ``n_keep`` mesh slots of ``env_idx`` and reset the rest.

        The remaining slots are disabled and their names cleared, so stale
        obstacles can neither collide nor be found by name. This is a no-op
        when the mesh cache has not been created yet.
        """
        if self._mesh_tensor_list is None:
            return
        self._mesh_tensor_list[2][env_idx, n_keep:] = 0
        names = self._env_mesh_names[env_idx]
        names[n_keep:] = [None] * max(0, len(names) - n_keep)
        self._env_n_mesh[env_idx] = n_keep

    def load_collision_model(
        self, world_model: WorldConfig, env_idx: int = 0, load_obb_obs: bool = True
    ):
        """Replace the meshes of environment ``env_idx`` with those of ``world_model``.

        The mesh cache is (re)created if it does not exist or is too small.
        Slots left over from a previously loaded, larger world are disabled.

        Args:
            world_model: world whose ``mesh`` list is uploaded to warp.
            env_idx: environment to load into.
            load_obb_obs: if True, also call the parent ``load_collision_model``
                for the remaining obstacles. If False, only ``self.world_model``
                is updated.
        """
        max_nmesh = len(world_model.mesh)
        if max_nmesh > 0:
            if self._mesh_tensor_list is None or self._mesh_tensor_list[0].shape[1] < max_nmesh:
                log_warn("Creating new Mesh cache: " + str(max_nmesh))
                self._create_mesh_cache(max_nmesh)
            # load all meshes as a batch:
            name_list, w_mid, w_inv_pose = self._load_batch_mesh_to_warp(world_model.mesh)
            self._mesh_tensor_list[0][env_idx, :max_nmesh] = w_mid
            self._mesh_tensor_list[1][env_idx, :max_nmesh, :7] = w_inv_pose
            self._mesh_tensor_list[2][env_idx, :max_nmesh] = 1
            self._env_mesh_names[env_idx][:max_nmesh] = name_list
            self._env_n_mesh[env_idx] = max_nmesh
            self.collision_types["mesh"] = True
        # Drop meshes a previous world left in this env beyond the new count.
        self._truncate_env_meshes(env_idx, max_nmesh)
        if load_obb_obs:
            super().load_collision_model(world_model, env_idx)
        else:
            self.world_model = world_model

    def load_batch_collision_model(self, world_config_list: List[WorldConfig]):
        """Load one world per environment, sizing the mesh cache for the largest world.

        Args:
            world_config_list: worlds; entry ``i`` is loaded into environment ``i``.
        """
        max_nmesh = max([len(x.mesh) for x in world_config_list])
        if self._mesh_tensor_list is None or self._mesh_tensor_list[0].shape[1] < max_nmesh:
            log_info("Creating new Mesh cache: " + str(max_nmesh))
            self._create_mesh_cache(max_nmesh)
        for env_idx, world_model in enumerate(world_config_list):
            self.load_collision_model(world_model, env_idx=env_idx, load_obb_obs=False)
        super().load_batch_collision_model(world_config_list)

    def _load_mesh_to_warp(self, mesh: Mesh) -> WarpMeshData:
        """Upload ``mesh`` vertices and faces to the warp device and build a ``wp.Mesh``."""
        verts, faces = mesh.get_mesh_data()
        v = wp.array(verts, dtype=wp.vec3, device=self._wp_device)
        f = wp.array(np.ravel(faces), dtype=int, device=self._wp_device)
        new_mesh = wp.Mesh(points=v, indices=f)
        return WarpMeshData(mesh.name, new_mesh.id, v, f, new_mesh)

    def _load_mesh_into_cache(self, mesh: Mesh) -> WarpMeshData:
        """Return the cached warp mesh for ``mesh.name``, uploading it on first use.

        The cache is keyed only by name. A second mesh with an already-cached
        name reuses the existing geometry.
        """
        cached = self._wp_mesh_cache.get(mesh.name)
        if cached is None:
            cached = self._load_mesh_to_warp(mesh)
            self._wp_mesh_cache[mesh.name] = cached
        else:
            log_info("Object already in warp cache, using existing instance: " + mesh.name)
        return cached

    def _load_batch_mesh_to_warp(self, mesh_list: List[Mesh]):
        """Load ``mesh_list`` into the warp cache.

        Returns:
            ``(names, ids, inverse_poses)``:
            - ``names``: mesh names, in input order.
            - ``ids``: int64 tensor of warp mesh ids, on ``tensor_args.device``.
            - ``inverse_poses``: pose vector of the inverse of each mesh pose.
        """
        name_list = []
        pose_list = []
        mesh_ids = []
        for mesh in mesh_list:
            m_data = self._load_mesh_into_cache(mesh)
            pose_list.append(mesh.pose)
            mesh_ids.append(m_data.m_id)
            name_list.append(m_data.name)
        # One host->device copy instead of one small write per mesh.
        id_list = torch.tensor(mesh_ids, device=self.tensor_args.device, dtype=torch.int64)
        pose_buffer = Pose.from_batch_list(pose_list, self.tensor_args)
        inv_pose_buffer = pose_buffer.inverse()
        return name_list, id_list, inv_pose_buffer.get_pose_vector()

    def add_mesh(self, new_mesh: Mesh, env_idx: int = 0):
        """Append ``new_mesh`` to the next free slot of environment ``env_idx``.

        Logs an error and returns without changes when the env's slots are full.
        """
        if self._env_n_mesh[env_idx] >= self._mesh_tensor_list[0].shape[1]:
            log_error(
                "Cannot add new mesh as we are at mesh cache limit, increase cache limit in WorldMeshCollision"
            )
            return
        wp_mesh_data = self._load_mesh_into_cache(new_mesh)
        # get mesh pose:
        w_obj_pose = Pose.from_list(new_mesh.pose, self.tensor_args)
        # add loaded mesh into scene:
        curr_idx = int(self._env_n_mesh[env_idx])
        self._mesh_tensor_list[0][env_idx, curr_idx] = wp_mesh_data.m_id
        self._mesh_tensor_list[1][env_idx, curr_idx, :7] = w_obj_pose.inverse().get_pose_vector()
        self._mesh_tensor_list[2][env_idx, curr_idx] = 1
        self._env_mesh_names[env_idx][curr_idx] = wp_mesh_data.name
        self._env_n_mesh[env_idx] = curr_idx + 1

    def get_mesh_idx(
        self,
        name: str,
        env_idx: int = 0,
    ) -> int:
        """Return the slot index of the mesh called ``name`` in environment ``env_idx``.

        Logs an error if the name is not present. ``list.index`` then raises
        ``ValueError``.
        """
        if name not in self._env_mesh_names[env_idx]:
            log_error("Obstacle with name: " + name + " not found in current world", exc_info=True)
        return self._env_mesh_names[env_idx].index(name)

    def create_collision_cache(self, mesh_cache=None, obb_cache=None, n_envs=None):
        """(Re)create the mesh and/or OBB caches.

        Args:
            mesh_cache: number of mesh slots per environment; skipped if None.
            obb_cache: forwarded to ``_create_obb_cache``; skipped if None.
            n_envs: if given, update ``self.n_envs`` before creating caches.
        """
        if n_envs is not None:
            self.n_envs = n_envs
        if mesh_cache is not None:
            self._create_mesh_cache(mesh_cache)
        if obb_cache is not None:
            self._create_obb_cache(obb_cache)

    def _create_mesh_cache(self, mesh_cache):
        """Allocate empty mesh tensors with ``mesh_cache`` slots per environment.

        This also resets the names, the per-env counts and the warp mesh cache.
        """
        self._env_n_mesh = torch.zeros(
            (self.n_envs), device=self.tensor_args.device, dtype=torch.int32
        )
        obs_enable = torch.zeros(
            (self.n_envs, mesh_cache), dtype=torch.uint8, device=self.tensor_args.device
        )
        obs_inverse_pose = torch.zeros(
            (self.n_envs, mesh_cache, 8),
            dtype=self.tensor_args.dtype,
            device=self.tensor_args.device,
        )
        # warp requires uint64 for mesh indices, supports conversion from int64 to uint64
        obs_ids = torch.zeros(
            (self.n_envs, mesh_cache), device=self.tensor_args.device, dtype=torch.int64
        )
        self._mesh_tensor_list = [
            obs_ids,
            obs_inverse_pose,
            obs_enable,
        ]  # 0=mesh idx, 1=pose, 2=mesh enable
        self.collision_types["mesh"] = True  # TODO: enable this after loading first mesh
        self._env_mesh_names = [[None for _ in range(mesh_cache)] for _ in range(self.n_envs)]
        self._wp_mesh_cache = {}

    def update_mesh_pose(
        self,
        w_obj_pose: Optional[Pose] = None,
        obj_w_pose: Optional[Pose] = None,
        name: Optional[str] = None,
        env_obj_idx: Optional[torch.Tensor] = None,
        env_idx: int = 0,
    ):
        """Set the pose of one mesh, addressed by ``name`` or by ``env_obj_idx``.

        Args:
            w_obj_pose: object pose in world frame; inverted before storing.
            obj_w_pose: already-inverted pose, stored as given.
            name: mesh name in ``env_idx`` (takes precedence over ``env_obj_idx``).
            env_obj_idx: slot index in ``env_idx``.
            env_idx: environment index.

        Raises:
            ValueError: if neither ``name`` nor ``env_obj_idx`` is given.
        """
        w_inv_pose = self._get_obstacle_poses(w_obj_pose, obj_w_pose)
        if name is not None:
            obs_idx = self.get_mesh_idx(name, env_idx)
            self._mesh_tensor_list[1][env_idx, obs_idx, :7] = w_inv_pose.get_pose_vector()
        elif env_obj_idx is not None:
            self._mesh_tensor_list[1][env_idx, env_obj_idx, :7] = w_inv_pose.get_pose_vector()
        else:
            raise ValueError("name or env_obj_idx needs to be given to update mesh pose")

    def update_all_mesh_pose(
        self,
        w_obj_pose: Optional[Pose] = None,
        obj_w_pose: Optional[Pose] = None,
        name: Optional[List[str]] = None,
        env_obj_idx: Optional[torch.Tensor] = None,
        env_idx: int = 0,
    ):
        """Update poses for a list of meshes in the same environment (not implemented).

        The pose arguments are still validated by ``_get_obstacle_poses`` first.

        Raises:
            NotImplementedError: always.
        """
        self._get_obstacle_poses(w_obj_pose, obj_w_pose)
        raise NotImplementedError

    def update_mesh_pose_env(
        self,
        w_obj_pose: Optional[Pose] = None,
        obj_w_pose: Optional[Pose] = None,
        name: Optional[str] = None,
        env_obj_idx: Optional[torch.Tensor] = None,
        env_idx: Optional[List[int]] = None,
    ):
        """Update pose of a single object in a list of environments (not implemented).

        ``env_idx`` defaults to ``None`` rather than a shared mutable ``[0]``.
        The pose arguments are still validated by ``_get_obstacle_poses`` first.

        Raises:
            NotImplementedError: always.
        """
        self._get_obstacle_poses(w_obj_pose, obj_w_pose)
        raise NotImplementedError

    def update_mesh_from_warp(
        self,
        warp_mesh_idx: int,
        w_obj_pose: Optional[Pose] = None,
        obj_w_pose: Optional[Pose] = None,
        obj_idx: int = 0,
        env_idx: int = 0,
        name: Optional[str] = None,
    ):
        """Point a mesh slot at an externally created warp mesh id and set its pose.

        Args:
            warp_mesh_idx: id of an existing ``wp.Mesh``.
            w_obj_pose / obj_w_pose: pose, as in ``update_mesh_pose``.
            obj_idx: slot to write; ignored when ``name`` is given.
            env_idx: environment index.
            name: existing mesh name to look up the slot. It is also stored as
                the slot's name (``None`` when not given).
        """
        if name is not None:
            obj_idx = self.get_mesh_idx(name, env_idx)
        if obj_idx >= self._mesh_tensor_list[0][env_idx].shape[0]:
            log_error("Out of cache memory")
            return
        w_inv_pose = self._get_obstacle_poses(w_obj_pose, obj_w_pose)
        self._mesh_tensor_list[0][env_idx, obj_idx] = warp_mesh_idx
        # Store the pose vector (not the Pose object), as update_mesh_pose does.
        self._mesh_tensor_list[1][env_idx, obj_idx, :7] = w_inv_pose.get_pose_vector()
        self._mesh_tensor_list[2][env_idx, obj_idx] = 1
        self._env_mesh_names[env_idx][obj_idx] = name
        if self._env_n_mesh[env_idx] <= obj_idx:
            self._env_n_mesh[env_idx] = obj_idx + 1

    def update_obstacle_pose(
        self,
        name: str,
        w_obj_pose: Pose,
        env_idx: int = 0,
    ):
        """Set the world pose of obstacle ``name``, searching meshes first, then OBBs."""
        if self._env_mesh_names is not None and name in self._env_mesh_names[env_idx]:
            self.update_mesh_pose(name=name, w_obj_pose=w_obj_pose, env_idx=env_idx)
        elif self._env_obbs_names is not None and name in self._env_obbs_names[env_idx]:
            self.update_obb_pose(name=name, w_obj_pose=w_obj_pose, env_idx=env_idx)
        else:
            log_error("Obstacle not found in world model: " + name)

    def enable_obstacle(
        self,
        name: str,
        enable: bool = True,
        env_idx: int = 0,
    ):
        """Enable or disable obstacle ``name``, searching meshes first, then OBBs."""
        if self._env_mesh_names is not None and name in self._env_mesh_names[env_idx]:
            self.enable_mesh(enable, name, None, env_idx)
        elif self._env_obbs_names is not None and name in self._env_obbs_names[env_idx]:
            self.enable_obb(enable, name, None, env_idx)
        else:
            log_error("Obstacle not found in world model: " + name)

    def enable_mesh(
        self,
        enable: bool = True,
        name: Optional[str] = None,
        env_mesh_idx: Optional[torch.Tensor] = None,
        env_idx: int = 0,
    ):
        """Enable or disable mesh obstacles.

        Args:
            enable: True to enable (flag 1), False to disable (flag 0).
            name: mesh name looked up in ``env_idx`` when ``env_mesh_idx`` is None.
            env_mesh_idx: index applied directly to the ``(n_envs, mesh_cache)``
                enable tensor. In that case ``env_idx`` is ignored.
                NOTE(review): presumably carries both env and slot indices — confirm.
            env_idx: environment used for the ``name`` lookup.
        """
        if env_mesh_idx is not None:
            self._mesh_tensor_list[2][env_mesh_idx] = int(enable)  # enable == 1
        else:
            obs_idx = self.get_mesh_idx(name, env_idx)
            self._mesh_tensor_list[2][env_idx, obs_idx] = int(enable)

    def _get_sdf(
        self,
        query_spheres,
        collision_query_buffer: CollisionQueryBuffer,
        weight: torch.Tensor,
        activation_distance: torch.Tensor,
        env_query_idx=None,
        return_loss=False,
    ):
        """Evaluate mesh distances for ``query_spheres`` via ``SdfMeshWarpPy.apply``.

        Uses the mesh buffers of ``collision_query_buffer`` and this object's
        mesh tensors. Returns whatever ``SdfMeshWarpPy.apply`` returns.
        """
        d = SdfMeshWarpPy.apply(
            query_spheres,
            collision_query_buffer.mesh_collision_buffer.distance_buffer,
            collision_query_buffer.mesh_collision_buffer.grad_distance_buffer,
            collision_query_buffer.mesh_collision_buffer.sparsity_index_buffer,
            weight,
            activation_distance,
            self._mesh_tensor_list[0],
            self._mesh_tensor_list[1],
            self._mesh_tensor_list[2],
            self._env_n_mesh,
            self.max_distance,
            env_query_idx,
            return_loss,
        )
        return d

    def _get_swept_sdf(
        self,
        query_spheres,
        collision_query_buffer: CollisionQueryBuffer,
        weight: torch.Tensor,
        activation_distance: torch.Tensor,
        speed_dt: torch.Tensor,
        sweep_steps: int,
        enable_speed_metric=False,
        env_query_idx=None,
        return_loss: bool = False,
    ):
        """Evaluate swept mesh distances via ``SweptSdfMeshWarpPy.apply``.

        Uses the mesh buffers of ``collision_query_buffer`` and this object's
        mesh tensors. Returns whatever ``SweptSdfMeshWarpPy.apply`` returns.
        """
        d = SweptSdfMeshWarpPy.apply(
            query_spheres,
            collision_query_buffer.mesh_collision_buffer.distance_buffer,
            collision_query_buffer.mesh_collision_buffer.grad_distance_buffer,
            collision_query_buffer.mesh_collision_buffer.sparsity_index_buffer,
            weight,
            activation_distance,
            speed_dt,
            self._mesh_tensor_list[0],
            self._mesh_tensor_list[1],
            self._mesh_tensor_list[2],
            self._env_n_mesh,
            sweep_steps,
            enable_speed_metric,
            self.max_distance,
            env_query_idx,
            return_loss,
        )
        return d

    def get_sphere_distance(
        self,
        query_sphere: torch.Tensor,
        collision_query_buffer: CollisionQueryBuffer,
        weight: torch.Tensor,
        activation_distance: torch.Tensor,
        env_query_idx: Optional[torch.Tensor] = None,
        return_loss: bool = False,
    ):
        """Sphere distance from meshes, summed with the primitive result when primitives are enabled.

        Falls back to the parent implementation when mesh collision is disabled.
        """
        if not self._collision_type_enabled("mesh"):
            return super().get_sphere_distance(
                query_sphere,
                collision_query_buffer,
                weight=weight,
                activation_distance=activation_distance,
                env_query_idx=env_query_idx,
                return_loss=return_loss,
            )
        d = self._get_sdf(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            activation_distance=activation_distance,
            env_query_idx=env_query_idx,
            return_loss=return_loss,
        )
        if not self._collision_type_enabled("primitive"):
            return d
        d_prim = super().get_sphere_distance(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            activation_distance=activation_distance,
            env_query_idx=env_query_idx,
            return_loss=return_loss,
        )
        return d.view(d_prim.shape) + d_prim

    def get_sphere_collision(
        self,
        query_sphere,
        collision_query_buffer: CollisionQueryBuffer,
        weight: torch.Tensor,
        activation_distance: torch.Tensor,
        env_query_idx=None,
        return_loss=False,
    ):
        """Sphere collision from meshes, summed with the primitive result when primitives are enabled.

        Falls back to the parent implementation when mesh collision is disabled.
        """
        if not self._collision_type_enabled("mesh"):
            return super().get_sphere_collision(
                query_sphere,
                collision_query_buffer,
                weight=weight,
                activation_distance=activation_distance,
                env_query_idx=env_query_idx,
                return_loss=return_loss,
            )
        d = self._get_sdf(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            activation_distance=activation_distance,
            env_query_idx=env_query_idx,
            return_loss=return_loss,
        )
        if not self._collision_type_enabled("primitive"):
            return d
        d_prim = super().get_sphere_collision(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            activation_distance=activation_distance,
            env_query_idx=env_query_idx,
            return_loss=return_loss,
        )
        return d.view(d_prim.shape) + d_prim

    def get_swept_sphere_distance(
        self,
        query_sphere,
        collision_query_buffer: CollisionQueryBuffer,
        weight: torch.Tensor,
        activation_distance: torch.Tensor,
        speed_dt: torch.Tensor,
        sweep_steps: int,
        enable_speed_metric=False,
        env_query_idx: Optional[torch.Tensor] = None,
        return_loss: bool = False,
    ):
        """Swept sphere distance from meshes, summed with the primitive result when primitives are enabled.

        Falls back to the parent implementation when mesh collision is disabled.
        """
        if not self._collision_type_enabled("mesh"):
            return super().get_swept_sphere_distance(
                query_sphere,
                collision_query_buffer,
                weight=weight,
                env_query_idx=env_query_idx,
                sweep_steps=sweep_steps,
                activation_distance=activation_distance,
                speed_dt=speed_dt,
                enable_speed_metric=enable_speed_metric,
                return_loss=return_loss,
            )
        d = self._get_swept_sdf(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            env_query_idx=env_query_idx,
            sweep_steps=sweep_steps,
            activation_distance=activation_distance,
            speed_dt=speed_dt,
            enable_speed_metric=enable_speed_metric,
            return_loss=return_loss,
        )
        if not self._collision_type_enabled("primitive"):
            return d
        d_prim = super().get_swept_sphere_distance(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            env_query_idx=env_query_idx,
            sweep_steps=sweep_steps,
            activation_distance=activation_distance,
            speed_dt=speed_dt,
            enable_speed_metric=enable_speed_metric,
            return_loss=return_loss,
        )
        return d.view(d_prim.shape) + d_prim

    def get_swept_sphere_collision(
        self,
        query_sphere,
        collision_query_buffer: CollisionQueryBuffer,
        weight: torch.Tensor,
        sweep_steps,
        activation_distance: torch.Tensor,
        speed_dt: torch.Tensor,
        enable_speed_metric=False,
        env_query_idx: Optional[torch.Tensor] = None,
        return_loss: bool = False,
    ):
        """Swept sphere collision from meshes, summed with the primitive result when primitives are enabled.

        Falls back to the parent implementation when mesh collision is disabled.
        """
        if not self._collision_type_enabled("mesh"):
            return super().get_swept_sphere_collision(
                query_sphere,
                collision_query_buffer,
                weight=weight,
                env_query_idx=env_query_idx,
                sweep_steps=sweep_steps,
                activation_distance=activation_distance,
                speed_dt=speed_dt,
                enable_speed_metric=enable_speed_metric,
                return_loss=return_loss,
            )
        d = self._get_swept_sdf(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            activation_distance=activation_distance,
            speed_dt=speed_dt,
            env_query_idx=env_query_idx,
            sweep_steps=sweep_steps,
            enable_speed_metric=enable_speed_metric,
            return_loss=return_loss,
        )
        if not self._collision_type_enabled("primitive"):
            return d
        d_prim = super().get_swept_sphere_collision(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            env_query_idx=env_query_idx,
            sweep_steps=sweep_steps,
            activation_distance=activation_distance,
            speed_dt=speed_dt,
            enable_speed_metric=enable_speed_metric,
            return_loss=return_loss,
        )
        return d.view(d_prim.shape) + d_prim

    def clear_cache(self):
        """Drop all warp meshes and reset every environment's mesh slots.

        Counts and names are reset as well as the enable flags. Later
        ``add_mesh`` calls then start from slot 0, and cleared names can no
        longer be found.
        """
        self._wp_mesh_cache = {}
        if self._mesh_tensor_list is not None:
            self._mesh_tensor_list[2][:] = 0
            self._env_n_mesh[:] = 0
            for names in self._env_mesh_names:
                names[:] = [None] * len(names)
        super().clear_cache()