Significantly improved convergence for mesh and cuboid, new ESDF collision.

This commit is contained in:
Balakumar Sundaralingam
2024-03-18 11:19:48 -07:00
parent 286b3820a5
commit b1f63e8778
100 changed files with 7587 additions and 2589 deletions

View File

@@ -56,8 +56,10 @@ def _get_version():
# `importlib_metadata` is the back ported library for older versions of python.
# Third Party
from importlib_metadata import version
return version("nvidia_curobo")
try:
return version("nvidia_curobo")
except:
return "v0.7.0-no-tag"
# Set `__version__` attribute

View File

@@ -280,6 +280,7 @@
<origin rpy="0 0 0" xyz="0 -0.04 0.0584"/>
<axis xyz="0 -1 0"/>
<limit effort="20" lower="0.0" upper="0.04" velocity="0.2"/>
<mimic joint="panda_finger_joint1"/>
</joint>
<link name="right_gripper">
<inertial>

View File

@@ -13,7 +13,7 @@ robot_cfg:
kinematics:
use_usd_kinematics: False
isaac_usd_path: "/Isaac/Robots/Franka/franka.usd"
usd_path: "robot/franka/franka_panda.usda"
usd_path: "robot/non_shipping/franka/franka_panda_meters.usda"
usd_robot_root: "/panda"
usd_flip_joints: ["panda_joint1","panda_joint2","panda_joint3","panda_joint4", "panda_joint5",
"panda_joint6","panda_joint7","panda_finger_joint1", "panda_finger_joint2"]

View File

@@ -39,7 +39,13 @@ robot_cfg:
}
urdf_path: robot/iiwa_allegro_description/iiwa.urdf
asset_root_path: robot/iiwa_allegro_description
mesh_link_names:
- iiwa7_link_1
- iiwa7_link_2
- iiwa7_link_3
- iiwa7_link_4
- iiwa7_link_5
- iiwa7_link_6
cspace:
joint_names:
[

View File

@@ -30,6 +30,25 @@ robot_cfg:
- ring_link_3
- thumb_link_2
- thumb_link_3
mesh_link_names:
- iiwa7_link_1
- iiwa7_link_2
- iiwa7_link_3
- iiwa7_link_4
- iiwa7_link_5
- iiwa7_link_6
- palm_link
- index_link_1
- index_link_2
- index_link_3
- middle_link_1
- middle_link_2
- middle_link_3
- ring_link_1
- ring_link_2
- ring_link_3
- thumb_link_2
- thumb_link_3
collision_sphere_buffer: 0.005
collision_spheres: spheres/iiwa_allegro.yml
ee_link: palm_link

View File

@@ -134,11 +134,11 @@ collision_spheres:
"radius": 0.022
panda_leftfinger:
- "center": [0.0, 0.01, 0.043]
"radius": 0.011 # 25
"radius": 0.011 #0.025 # 0.011
- "center": [0.0, 0.02, 0.015]
"radius": 0.011 # 25
"radius": 0.011 #0.025 # 0.011
panda_rightfinger:
- "center": [0.0, -0.01, 0.043]
"radius": 0.011 #25
"radius": 0.011 #0.025 #0.011
- "center": [0.0, -0.02, 0.015]
"radius": 0.011 #25
"radius": 0.011 #0.025 #0.011

View File

@@ -58,7 +58,7 @@ collision_spheres:
radius: 0.07
tool0:
- center: [0, 0, 0.12]
radius: 0.05
radius: -0.01
camera_mount:
- center: [0, 0.11, -0.01]
radius: 0.06

View File

@@ -38,7 +38,7 @@ robot_cfg:
'wrist_1_link': 0,
'wrist_2_link': 0,
'wrist_3_link' : 0,
'tool0': 0,
'tool0': 0.05,
}
mesh_link_names: [ 'shoulder_link','upper_arm_link', 'forearm_link', 'wrist_1_link', 'wrist_2_link' ,'wrist_3_link' ]
lock_joints: null

View File

@@ -92,7 +92,7 @@ robot_cfg:
"radius": 0.043
tool0:
- "center": [0.001, 0.001, 0.05]
"radius": 0.05
"radius": -0.01 #0.05
collision_sphere_buffer: 0.005
@@ -109,6 +109,7 @@ robot_cfg:
'wrist_1_link': 0,
'wrist_2_link': 0,
'wrist_3_link' : 0,
'tool0': 0.025,
}
use_global_cumul: True

View File

@@ -55,17 +55,17 @@ cost:
bound_cfg:
weight: [5000.0, 50000.0, 50000.0, 50000.0] # needs to be 3 values
smooth_weight: [0.0,5000.0, 50.0, 0.0] # [vel, acc, jerk,]
smooth_weight: [0.0,5000.0, 50.0, 0.0] #[0.0,5000.0, 50.0, 0.0] # [vel, acc, jerk,]
run_weight_velocity: 0.0
run_weight_acceleration: 1.0
run_weight_jerk: 1.0
activation_distance: [0.1,0.1,0.1,10.0] # for position, velocity, acceleration and jerk
activation_distance: [0.1,0.1,0.1,0.1] # for position, velocity, acceleration and jerk
null_space_weight: [0.0]
primitive_collision_cfg:
weight: 1000000.0
weight: 1000000.0 #1000000.0 1000000
use_sweep: True
sweep_steps: 6
sweep_steps: 4
classify: False
use_sweep_kernel: True
use_speed_metric: True
@@ -79,7 +79,7 @@ cost:
lbfgs:
n_iters: 400 # 400
n_iters: 300 # 400
inner_iters: 25
cold_start_n_iters: null
min_iters: 25
@@ -89,7 +89,7 @@ lbfgs:
cost_delta_threshold: 1.0
cost_relative_threshold: 0.999 #0.999
epsilon: 0.01
history: 15 #15
history: 27 #15
use_cuda_graph: True
n_problems: 1
store_debug: False

View File

@@ -47,7 +47,7 @@ cost:
activation_distance: [0.1]
null_space_weight: [1.0]
primitive_collision_cfg:
weight: 50000.0
weight: 5000.0
use_sweep: False
classify: False
activation_distance: 0.01
@@ -57,11 +57,11 @@ cost:
lbfgs:
n_iters: 100 #60
inner_iters: 25
n_iters: 80 #60
inner_iters: 20
cold_start_n_iters: null
min_iters: 20
line_search_scale: [0.01, 0.3, 0.7, 1.0]
line_search_scale: [0.1, 0.3, 0.7, 1.0]
fixed_iters: True
cost_convergence: 1e-7
cost_delta_threshold: 1e-6 #0.0001

View File

@@ -40,7 +40,7 @@ cost:
link_pose_cfg:
vec_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] # orientation, position for all timesteps
run_vec_weight: [1.00,1.00,1.00,1.0,1.0,1.0] # running weight orientation, position
run_vec_weight: [0.00,0.00,0.00,0.0,0.0,0.0] # running weight orientation, position
weight: [2000,50000.0,30,50] #[150.0, 2000.0, 30, 40]
vec_convergence: [0.0,0.0] # orientation, position, orientation metric activation, position metric activation
terminal: True
@@ -54,19 +54,17 @@ cost:
bound_cfg:
weight: [5000.0, 50000.0, 50000.0,50000.0] # needs to be 3 values
#weight: [000.0, 000.0, 000.0,000.0]
smooth_weight: [0.0,10000.0,10.0, 0.0] # [vel, acc, jerk, alpha_v-not-used]
#smooth_weight: [0.0,0000.0,0.0, 0.0] # [vel, acc, jerk, alpha_v-not-used]
run_weight_velocity: 0.00
run_weight_acceleration: 1.0
run_weight_jerk: 1.0
activation_distance: [0.1,0.1,0.1,10.0] # for position, velocity, acceleration and jerk
activation_distance: [0.1,0.1,0.1,0.1] # for position, velocity, acceleration and jerk
null_space_weight: [0.0]
primitive_collision_cfg:
weight: 100000.0
use_sweep: True
sweep_steps: 6
sweep_steps: 4
classify: False
use_sweep_kernel: True
use_speed_metric: True
@@ -81,11 +79,11 @@ cost:
lbfgs:
n_iters: 125 #175
n_iters: 100 #175
inner_iters: 25
cold_start_n_iters: null
min_iters: 25
line_search_scale: [0.01,0.3,0.7,1.0] #[0.01,0.2, 0.3,0.5,0.7,0.9, 1.0] #
line_search_scale: [0.01,0.3,0.7,1.0]
fixed_iters: True
cost_convergence: 0.01
cost_delta_threshold: 2000.0

View File

@@ -33,7 +33,7 @@ graph:
sample_pts: 1500
node_similarity_distance: 0.1
rejection_ratio: 20
rejection_ratio: 10
k_nn: 15
max_buffer: 10000
max_cg_buffer: 1000

View File

@@ -34,7 +34,7 @@ cost:
run_weight: 1.0
link_pose_cfg:
vec_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
weight: [30, 50, 10, 10] #[20.0, 100.0]
weight: [0, 50, 10, 10] #[20.0, 100.0]
vec_convergence: [0.00, 0.000] # orientation, position
terminal: False
use_metric: True
@@ -67,7 +67,7 @@ mppi:
cov_type : "DIAG_A" #
kappa : 0.01
null_act_frac : 0.0
sample_mode : 'BEST'
sample_mode : 'MEAN'
base_action : 'REPEAT'
squash_fn : 'CLAMP'
n_problems : 1

View File

@@ -89,7 +89,7 @@ mppi:
sample_mode : 'BEST'
base_action : 'REPEAT'
squash_fn : 'CLAMP'
n_problems : 1
n_problems : 1
use_cuda_graph : True
seed : 0
store_debug : False

View File

@@ -39,7 +39,7 @@ cost:
link_pose_cfg:
vec_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
run_vec_weight: [0.00,0.00,0.00,0.0,0.0,0.0] # running weight orientation, position
run_vec_weight: [0.0,0.0,0.0,0.0,0.0,0.0] # running weight
weight: [0.0, 5000.0, 40, 40]
vec_convergence: [0.0,0.0,1000.0,1000.0]
terminal: True
@@ -63,11 +63,11 @@ cost:
primitive_collision_cfg:
weight: 5000.0
use_sweep: True
use_sweep: False
classify: False
sweep_steps: 4
use_sweep_kernel: True
use_speed_metric: True
use_sweep_kernel: False
use_speed_metric: False
speed_dt: 0.01 # used only for speed metric
activation_distance: 0.025
@@ -92,7 +92,7 @@ mppi:
cov_type : "DIAG_A" #
kappa : 0.001
null_act_frac : 0.0
sample_mode : 'BEST'
sample_mode : 'MEAN'
base_action : 'REPEAT'
squash_fn : 'CLAMP'
n_problems : 1

View File

@@ -180,7 +180,7 @@ class CudaRobotModel(CudaRobotModelConfig):
self._batch_robot_spheres = torch.zeros(
(self._batch_size, self.kinematics_config.total_spheres, 4),
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
dtype=self.tensor_args.collision_geometry_dtype,
)
self._grad_out_q = torch.zeros(
(self._batch_size, self.get_dof()),

View File

@@ -23,6 +23,7 @@ from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
from curobo.types.state import JointState
from curobo.types.tensor import T_DOF
from curobo.util.logger import log_error
from curobo.util.tensor_util import clone_if_not_none, copy_if_not_none
@@ -138,6 +139,13 @@ class CSpaceConfig:
self.acceleration_scale = self.tensor_args.to_device(self.acceleration_scale)
if isinstance(self.jerk_scale, List):
self.jerk_scale = self.tensor_args.to_device(self.jerk_scale)
# check shapes:
if self.retract_config is not None:
dof = self.retract_config.shape
if self.cspace_distance_weight is not None and self.cspace_distance_weight.shape != dof:
log_error("cspace_distance_weight shape does not match retract_config")
if self.null_space_weight is not None and self.null_space_weight.shape != dof:
log_error("null_space_weight shape does not match retract_config")
def inplace_reindex(self, joint_names: List[str]):
new_index = [self.joint_names.index(j) for j in joint_names]
@@ -207,8 +215,8 @@ class CSpaceConfig:
):
retract_config = ((joint_position_upper + joint_position_lower) / 2).flatten()
n_dof = retract_config.shape[-1]
null_space_weight = torch.ones(n_dof, **vars(tensor_args))
cspace_distance_weight = torch.ones(n_dof, **vars(tensor_args))
null_space_weight = torch.ones(n_dof, **(tensor_args.as_torch_dict()))
cspace_distance_weight = torch.ones(n_dof, **(tensor_args.as_torch_dict()))
return CSpaceConfig(
joint_names,
retract_config,
@@ -289,8 +297,8 @@ class KinematicsTensorConfig:
retract_config = (
(self.joint_limits.position[1] + self.joint_limits.position[0]) / 2
).flatten()
null_space_weight = torch.ones(self.n_dof, **vars(self.tensor_args))
cspace_distance_weight = torch.ones(self.n_dof, **vars(self.tensor_args))
null_space_weight = torch.ones(self.n_dof, **(self.tensor_args.as_torch_dict()))
cspace_distance_weight = torch.ones(self.n_dof, **(self.tensor_args.as_torch_dict()))
joint_names = self.joint_names
self.cspace = CSpaceConfig(
joint_names,

View File

@@ -175,7 +175,11 @@ class UrdfKinematicsParser(KinematicsParser):
return txt
def get_link_mesh(self, link_name):
m = self._robot.link_map[link_name].visuals[0].geometry.mesh
link_data = self._robot.link_map[link_name]
if len(link_data.visuals) == 0:
log_error(link_name + " not found in urdf, remove from mesh_link_names")
m = link_data.visuals[0].geometry.mesh
mesh_pose = self._robot.link_map[link_name].visuals[0].origin
# read visual material:
if mesh_pose is None:

View File

@@ -57,7 +57,8 @@ std::vector<torch::Tensor>swept_sphere_obb_clpt(
const bool enable_speed_metric,
const bool transform_back,
const bool compute_distance,
const bool use_batch_env);
const bool use_batch_env,
const bool sum_collisions);
std::vector<torch::Tensor>
sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 4
@@ -66,6 +67,7 @@ sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 4
torch::Tensor sparsity_idx,
const torch::Tensor weight,
const torch::Tensor activation_distance,
const torch::Tensor max_distance,
const torch::Tensor obb_accel, // n_boxes, 4, 4
const torch::Tensor obb_bounds, // n_boxes, 3
const torch::Tensor obb_pose, // n_boxes, 4, 4
@@ -78,8 +80,52 @@ sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 4
const int n_spheres,
const bool transform_back,
const bool compute_distance,
const bool use_batch_env);
const bool use_batch_env,
const bool sum_collisions,
const bool compute_esdf);
std::vector<torch::Tensor>
sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
torch::Tensor distance,
torch::Tensor closest_point, // batch size, 3
torch::Tensor sparsity_idx, const torch::Tensor weight,
const torch::Tensor activation_distance,
const torch::Tensor max_distance,
const torch::Tensor grid_features, // n_boxes, 4, 4
const torch::Tensor grid_params, // n_boxes, 3
const torch::Tensor grid_pose, // n_boxes, 4, 4
const torch::Tensor grid_enable, // n_boxes, 4, 4
const torch::Tensor n_env_grid,
const torch::Tensor env_query_idx, // n_boxes, 4, 4
const int max_nobs, const int batch_size, const int horizon,
const int n_spheres, const bool transform_back,
const bool compute_distance, const bool use_batch_env,
const bool sum_collisions,
const bool compute_esdf);
std::vector<torch::Tensor>
swept_sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
torch::Tensor distance,
torch::Tensor closest_point, // batch size, 3
torch::Tensor sparsity_idx, const torch::Tensor weight,
const torch::Tensor activation_distance,
const torch::Tensor max_distance,
const torch::Tensor speed_dt,
const torch::Tensor grid_features, // n_boxes, 4, 4
const torch::Tensor grid_params, // n_boxes, 3
const torch::Tensor grid_pose, // n_boxes, 4, 4
const torch::Tensor grid_enable, // n_boxes, 4, 4
const torch::Tensor n_env_grid,
const torch::Tensor env_query_idx, // n_boxes, 4, 4
const int max_nobs,
const int batch_size,
const int horizon,
const int n_spheres,
const int sweep_steps,
const bool enable_speed_metric,
const bool transform_back,
const bool compute_distance,
const bool use_batch_env,
const bool sum_collisions);
std::vector<torch::Tensor>pose_distance(
torch::Tensor out_distance,
torch::Tensor out_position_distance,
@@ -159,11 +205,11 @@ std::vector<torch::Tensor>self_collision_distance_wrapper(
std::vector<torch::Tensor>sphere_obb_clpt_wrapper(
const torch::Tensor sphere_position, // batch_size, 4
torch::Tensor distance,
torch::Tensor closest_point, // batch size, 3
torch::Tensor sparsity_idx, const torch::Tensor weight,
const torch::Tensor activation_distance,
const torch::Tensor max_distance,
const torch::Tensor obb_accel, // n_boxes, 4, 4
const torch::Tensor obb_bounds, // n_boxes, 3
const torch::Tensor obb_pose, // n_boxes, 4, 4
@@ -171,8 +217,10 @@ std::vector<torch::Tensor>sphere_obb_clpt_wrapper(
const torch::Tensor n_env_obb, // n_boxes, 4, 4
const torch::Tensor env_query_idx, // n_boxes, 4, 4
const int max_nobs, const int batch_size, const int horizon,
const int n_spheres, const bool transform_back, const bool compute_distance,
const bool use_batch_env)
const int n_spheres,
const bool transform_back, const bool compute_distance,
const bool use_batch_env, const bool sum_collisions = true,
const bool compute_esdf = false)
{
const at::cuda::OptionalCUDAGuard guard(sphere_position.device());
@@ -185,9 +233,9 @@ std::vector<torch::Tensor>sphere_obb_clpt_wrapper(
CHECK_INPUT(obb_accel);
return sphere_obb_clpt(
sphere_position, distance, closest_point, sparsity_idx, weight,
activation_distance, obb_accel, obb_bounds, obb_pose, obb_enable,
activation_distance, max_distance, obb_accel, obb_bounds, obb_pose, obb_enable,
n_env_obb, env_query_idx, max_nobs, batch_size, horizon, n_spheres,
transform_back, compute_distance, use_batch_env);
transform_back, compute_distance, use_batch_env, sum_collisions, compute_esdf);
}
std::vector<torch::Tensor>swept_sphere_obb_clpt_wrapper(
@@ -205,7 +253,7 @@ std::vector<torch::Tensor>swept_sphere_obb_clpt_wrapper(
const int max_nobs, const int batch_size, const int horizon,
const int n_spheres, const int sweep_steps, const bool enable_speed_metric,
const bool transform_back, const bool compute_distance,
const bool use_batch_env)
const bool use_batch_env, const bool sum_collisions = true)
{
const at::cuda::OptionalCUDAGuard guard(sphere_position.device());
@@ -218,7 +266,37 @@ std::vector<torch::Tensor>swept_sphere_obb_clpt_wrapper(
distance, closest_point, sparsity_idx, weight, activation_distance,
speed_dt, obb_accel, obb_bounds, obb_pose, obb_enable, n_env_obb,
env_query_idx, max_nobs, batch_size, horizon, n_spheres, sweep_steps,
enable_speed_metric, transform_back, compute_distance, use_batch_env);
enable_speed_metric, transform_back, compute_distance, use_batch_env, sum_collisions);
}
std::vector<torch::Tensor>
sphere_voxel_clpt_wrapper(const torch::Tensor sphere_position, // batch_size, 3
torch::Tensor distance,
torch::Tensor closest_point, // batch size, 3
torch::Tensor sparsity_idx, const torch::Tensor weight,
const torch::Tensor activation_distance,
const torch::Tensor max_distance,
const torch::Tensor grid_features, // n_boxes, 4, 4
const torch::Tensor grid_params, // n_boxes, 3
const torch::Tensor grid_pose, // n_boxes, 4, 4
const torch::Tensor grid_enable, // n_boxes, 4, 4
const torch::Tensor n_env_grid,
const torch::Tensor env_query_idx, // n_boxes, 4, 4
const int max_ngrid, const int batch_size, const int horizon,
const int n_spheres, const bool transform_back,
const bool compute_distance, const bool use_batch_env,
const bool sum_collisions,
const bool compute_esdf)
{
const at::cuda::OptionalCUDAGuard guard(sphere_position.device());
CHECK_INPUT(distance);
CHECK_INPUT(closest_point);
CHECK_INPUT(sphere_position);
return sphere_voxel_clpt(sphere_position, distance, closest_point, sparsity_idx, weight,
activation_distance, max_distance, grid_features, grid_params,
grid_pose, grid_enable, n_env_grid, env_query_idx, max_ngrid, batch_size, horizon, n_spheres,
transform_back, compute_distance, use_batch_env, sum_collisions, compute_esdf);
}
std::vector<torch::Tensor>pose_distance_wrapper(
@@ -297,6 +375,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"Closest Point OBB(curobolib)");
m.def("swept_closest_point", &swept_sphere_obb_clpt_wrapper,
"Swept Closest Point OBB(curobolib)");
m.def("closest_point_voxel", &sphere_voxel_clpt_wrapper,
"Closest Point Voxel(curobolib)");
m.def("swept_closest_point_voxel", &swept_sphere_voxel_clpt,
"Swpet Closest Point Voxel(curobolib)");
m.def("self_collision_distance", &self_collision_distance_wrapper,
"Self Collision Distance (curobolib)");

View File

@@ -14,38 +14,6 @@
#include <c10/cuda/CUDAGuard.h>
// CUDA forward declarations
std::vector<torch::Tensor>reduce_cuda(torch::Tensor vec,
torch::Tensor vec2,
torch::Tensor rho_buffer,
torch::Tensor sum,
const int batch_size,
const int v_dim,
const int m);
std::vector<torch::Tensor>
lbfgs_step_cuda(torch::Tensor step_vec,
torch::Tensor rho_buffer,
torch::Tensor y_buffer,
torch::Tensor s_buffer,
torch::Tensor grad_q,
const float epsilon,
const int batch_size,
const int m,
const int v_dim);
std::vector<torch::Tensor>
lbfgs_update_cuda(torch::Tensor rho_buffer,
torch::Tensor y_buffer,
torch::Tensor s_buffer,
torch::Tensor q,
torch::Tensor grad_q,
torch::Tensor x_0,
torch::Tensor grad_0,
const int batch_size,
const int m,
const int v_dim);
std::vector<torch::Tensor>
lbfgs_cuda_fuse(torch::Tensor step_vec,
torch::Tensor rho_buffer,
@@ -59,7 +27,8 @@ lbfgs_cuda_fuse(torch::Tensor step_vec,
const int batch_size,
const int m,
const int v_dim,
const bool stable_mode);
const bool stable_mode,
const bool use_shared_buffers);
// C++ interface
@@ -71,58 +40,12 @@ lbfgs_cuda_fuse(torch::Tensor step_vec,
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
lbfgs_step_call(torch::Tensor step_vec, torch::Tensor rho_buffer,
torch::Tensor y_buffer, torch::Tensor s_buffer,
torch::Tensor grad_q, const float epsilon, const int batch_size,
const int m, const int v_dim)
{
CHECK_INPUT(step_vec);
CHECK_INPUT(rho_buffer);
CHECK_INPUT(y_buffer);
CHECK_INPUT(s_buffer);
CHECK_INPUT(grad_q);
const at::cuda::OptionalCUDAGuard guard(grad_q.device());
return lbfgs_step_cuda(step_vec, rho_buffer, y_buffer, s_buffer, grad_q,
epsilon, batch_size, m, v_dim);
}
std::vector<torch::Tensor>
lbfgs_update_call(torch::Tensor rho_buffer, torch::Tensor y_buffer,
torch::Tensor s_buffer, torch::Tensor q, torch::Tensor grad_q,
torch::Tensor x_0, torch::Tensor grad_0, const int batch_size,
const int m, const int v_dim)
{
CHECK_INPUT(rho_buffer);
CHECK_INPUT(y_buffer);
CHECK_INPUT(s_buffer);
CHECK_INPUT(grad_q);
CHECK_INPUT(x_0);
CHECK_INPUT(grad_0);
CHECK_INPUT(q);
const at::cuda::OptionalCUDAGuard guard(grad_q.device());
return lbfgs_update_cuda(rho_buffer, y_buffer, s_buffer, q, grad_q, x_0,
grad_0, batch_size, m, v_dim);
}
std::vector<torch::Tensor>
reduce_cuda_call(torch::Tensor vec, torch::Tensor vec2,
torch::Tensor rho_buffer, torch::Tensor sum,
const int batch_size, const int v_dim, const int m)
{
const at::cuda::OptionalCUDAGuard guard(sum.device());
return reduce_cuda(vec, vec2, rho_buffer, sum, batch_size, v_dim, m);
}
std::vector<torch::Tensor>
lbfgs_call(torch::Tensor step_vec, torch::Tensor rho_buffer,
torch::Tensor y_buffer, torch::Tensor s_buffer, torch::Tensor q,
torch::Tensor grad_q, torch::Tensor x_0, torch::Tensor grad_0,
const float epsilon, const int batch_size, const int m,
const int v_dim, const bool stable_mode)
const int v_dim, const bool stable_mode, const bool use_shared_buffers)
{
CHECK_INPUT(step_vec);
CHECK_INPUT(rho_buffer);
@@ -136,13 +59,11 @@ lbfgs_call(torch::Tensor step_vec, torch::Tensor rho_buffer,
return lbfgs_cuda_fuse(step_vec, rho_buffer, y_buffer, s_buffer, q, grad_q,
x_0, grad_0, epsilon, batch_size, m, v_dim,
stable_mode);
stable_mode, use_shared_buffers);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("step", &lbfgs_step_call, "L-BFGS step (CUDA)");
m.def("update", &lbfgs_update_call, "L-BFGS Update (CUDA)");
m.def("forward", &lbfgs_call, "L-BFGS Update + Step (CUDA)");
m.def("debug_reduce", &reduce_cuda_call, "L-BFGS Debug");
}

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,6 @@
* its affiliates is strictly prohibited.
*/
#pragma once
#include <cuda.h>
#include <torch/extension.h>
#include <vector>

View File

@@ -185,11 +185,11 @@ namespace Curobo
if (coll_matrix[i * nspheres + j] == 1)
{
float4 sph1 = __rs_shared[i];
if ((sph1.w <= 0.0) || (sph2.w <= 0.0))
{
continue;
}
//
//if ((sph1.w <= 0.0) || (sph2.w <= 0.0))
//{
// continue;
//}
float r_diff = sph1.w + sph2.w;
float d = sqrt((sph1.x - sph2.x) * (sph1.x - sph2.x) +
(sph1.y - sph2.y) * (sph1.y - sph2.y) +
@@ -380,10 +380,10 @@ namespace Curobo
float4 sph1 = __rs_shared[NBPB * i + l];
float4 sph2 = __rs_shared[NBPB * j + l];
if ((sph1.w <= 0.0) || (sph2.w <= 0.0))
{
continue;
}
//if ((sph1.w <= 0.0) || (sph2.w <= 0.0))
//{
// continue;
//}
float r_diff =
sph1.w + sph2.w; // sum of two radii, radii include respective offsets
float d = sqrt((sph1.x - sph2.x) * (sph1.x - sph2.x) +

File diff suppressed because it is too large Load Diff

View File

@@ -342,10 +342,8 @@ namespace Curobo
float out_pos = 0.0, out_vel = 0.0, out_acc = 0.0, out_jerk = 0.0;
float st_pos = 0.0, st_vel = 0.0, st_acc = 0.0;
const int b_addrs = b_idx * horizon * dof;
const int b_addrs_action = b_idx * (horizon - 4) * dof;
float in_pos[5]; // create a 5 value scalar
const float acc_scale = 1.0;
#pragma unroll 5

View File

@@ -13,6 +13,7 @@ import torch
# CuRobo
from curobo.util.logger import log_warn
from curobo.util.torch_utils import get_torch_jit_decorator
try:
# CuRobo
@@ -235,7 +236,7 @@ def get_pose_distance_backward(
return r[0], r[1]
@torch.jit.script
@get_torch_jit_decorator()
def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
grad_vec = grad_g_dist + (grad_out_distance * weight)
grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec
@@ -243,7 +244,7 @@ def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
# full method:
@torch.jit.script
@get_torch_jit_decorator()
def backward_full_PoseError_jit(
grad_out_distance, grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
):
@@ -570,6 +571,7 @@ class SdfSphereOBB(torch.autograd.Function):
sparsity_idx,
weight,
activation_distance,
max_distance,
box_accel,
box_dims,
box_pose,
@@ -584,6 +586,8 @@ class SdfSphereOBB(torch.autograd.Function):
compute_distance,
use_batch_env,
return_loss: bool = False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
r = geom_cu.closest_point(
query_sphere,
@@ -592,6 +596,7 @@ class SdfSphereOBB(torch.autograd.Function):
sparsity_idx,
weight,
activation_distance,
max_distance,
box_accel,
box_dims,
box_pose,
@@ -605,8 +610,11 @@ class SdfSphereOBB(torch.autograd.Function):
transform_back,
compute_distance,
use_batch_env,
sum_collisions,
compute_esdf,
)
# r[1][r[1]!=r[1]] = 0.0
ctx.compute_esdf = compute_esdf
ctx.return_loss = return_loss
ctx.save_for_backward(r[1])
return r[0]
@@ -615,6 +623,8 @@ class SdfSphereOBB(torch.autograd.Function):
def backward(ctx, grad_output):
grad_pt = None
if ctx.needs_input_grad[0]:
# if ctx.compute_esdf:
# raise NotImplementedError("Gradients not implemented for compute_esdf=True")
(r,) = ctx.saved_tensors
if ctx.return_loss:
r = r * grad_output.unsqueeze(-1)
@@ -640,6 +650,9 @@ class SdfSphereOBB(torch.autograd.Function):
None,
None,
None,
None,
None,
None,
)
@@ -670,6 +683,7 @@ class SdfSweptSphereOBB(torch.autograd.Function):
compute_distance,
use_batch_env,
return_loss: bool = False,
sum_collisions: bool = True,
):
r = geom_cu.swept_closest_point(
query_sphere,
@@ -694,6 +708,7 @@ class SdfSweptSphereOBB(torch.autograd.Function):
transform_back,
compute_distance,
use_batch_env,
sum_collisions,
)
ctx.return_loss = return_loss
ctx.save_for_backward(
@@ -733,4 +748,200 @@ class SdfSweptSphereOBB(torch.autograd.Function):
None,
None,
None,
None,
)
class SdfSphereVoxel(torch.autograd.Function):
@staticmethod
def forward(
ctx,
query_sphere,
out_buffer,
grad_out_buffer,
sparsity_idx,
weight,
activation_distance,
max_distance,
grid_features,
grid_params,
grid_pose,
grid_enable,
n_env_grid,
env_query_idx,
max_nobs,
batch_size,
horizon,
n_spheres,
transform_back,
compute_distance,
use_batch_env,
return_loss: bool = False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
r = geom_cu.closest_point_voxel(
query_sphere,
out_buffer,
grad_out_buffer,
sparsity_idx,
weight,
activation_distance,
max_distance,
grid_features,
grid_params,
grid_pose,
grid_enable,
n_env_grid,
env_query_idx,
max_nobs,
batch_size,
horizon,
n_spheres,
transform_back,
compute_distance,
use_batch_env,
sum_collisions,
compute_esdf,
)
ctx.compute_esdf = compute_esdf
ctx.return_loss = return_loss
ctx.save_for_backward(r[1])
return r[0]
@staticmethod
def backward(ctx, grad_output):
grad_pt = None
if ctx.needs_input_grad[0]:
# if ctx.compute_esdf:
# raise NotImplementedError("Gradients not implemented for compute_esdf=True")
(r,) = ctx.saved_tensors
if ctx.return_loss:
r = r * grad_output.unsqueeze(-1)
grad_pt = r
return (
grad_pt,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class SdfSweptSphereVoxel(torch.autograd.Function):
@staticmethod
def forward(
ctx,
query_sphere,
out_buffer,
grad_out_buffer,
sparsity_idx,
weight,
activation_distance,
max_distance,
speed_dt,
grid_features,
grid_params,
grid_pose,
grid_enable,
n_env_grid,
env_query_idx,
max_nobs,
batch_size,
horizon,
n_spheres,
sweep_steps,
enable_speed_metric,
transform_back,
compute_distance,
use_batch_env,
return_loss: bool = False,
sum_collisions: bool = True,
):
r = geom_cu.swept_closest_point_voxel(
query_sphere,
out_buffer,
grad_out_buffer,
sparsity_idx,
weight,
activation_distance,
max_distance,
speed_dt,
grid_features,
grid_params,
grid_pose,
grid_enable,
n_env_grid,
env_query_idx,
max_nobs,
batch_size,
horizon,
n_spheres,
sweep_steps,
enable_speed_metric,
transform_back,
compute_distance,
use_batch_env,
sum_collisions,
)
ctx.return_loss = return_loss
ctx.save_for_backward(
r[1],
)
return r[0]
@staticmethod
def backward(ctx, grad_output):
grad_pt = None
if ctx.needs_input_grad[0]:
(r,) = ctx.saved_tensors
if ctx.return_loss:
r = r * grad_output.unsqueeze(-1)
grad_pt = r
return (
grad_pt,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)

View File

@@ -39,7 +39,8 @@ except ImportError:
class LBFGScu(Function):
@staticmethod
def _call_cuda(
def forward(
ctx,
step_vec,
rho_buffer,
y_buffer,
@@ -50,6 +51,7 @@ class LBFGScu(Function):
grad_0,
epsilon=0.1,
stable_mode=False,
use_shared_buffers=True,
):
m, b, v_dim, _ = y_buffer.shape
@@ -67,39 +69,12 @@ class LBFGScu(Function):
m,
v_dim,
stable_mode,
use_shared_buffers,
)
step_v = R[0].view(step_vec.shape)
return step_v
@staticmethod
def forward(
ctx,
step_vec,
rho_buffer,
y_buffer,
s_buffer,
q,
grad_q,
x_0,
grad_0,
epsilon=0.1,
stable_mode=False,
):
R = LBFGScu._call_cuda(
step_vec,
rho_buffer,
y_buffer,
s_buffer,
q,
grad_q,
x_0,
grad_0,
epsilon=epsilon,
stable_mode=stable_mode,
)
# ctx.save_for_backward(batch_spheres, robot_spheres, link_mats, link_sphere_map)
return R
return step_v
@staticmethod
def backward(ctx, grad_output):
@@ -109,4 +84,5 @@ class LBFGScu(Function):
None,
None,
None,
None,
)

View File

@@ -12,8 +12,11 @@
# Third Party
import torch
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
@torch.jit.script
@get_torch_jit_decorator()
def project_depth_to_pointcloud(depth_image: torch.Tensor, intrinsics_matrix: torch.Tensor):
"""Projects numpy depth image to point cloud.
@@ -43,7 +46,7 @@ def project_depth_to_pointcloud(depth_image: torch.Tensor, intrinsics_matrix: to
return raw_pc
@torch.jit.script
@get_torch_jit_decorator()
def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor):
"""Projects numpy depth image to point cloud.
@@ -54,10 +57,10 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
Returns:
array of float (h, w, 3)
"""
fx = intrinsics_matrix[:, 0, 0]
fy = intrinsics_matrix[:, 1, 1]
cx = intrinsics_matrix[:, 0, 2]
cy = intrinsics_matrix[:, 1, 2]
fx = intrinsics_matrix[:, 0:1, 0:1]
fy = intrinsics_matrix[:, 1:2, 1:2]
cx = intrinsics_matrix[:, 0:1, 2:3]
cy = intrinsics_matrix[:, 1:2, 2:3]
input_x = torch.arange(width, dtype=torch.float32, device=intrinsics_matrix.device)
input_y = torch.arange(height, dtype=torch.float32, device=intrinsics_matrix.device)
@@ -73,7 +76,6 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
device=intrinsics_matrix.device,
dtype=torch.float32,
)
output_x = (input_x - cx) / fx
output_y = (input_y - cy) / fy
@@ -84,7 +86,7 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
return rays
@torch.jit.script
@get_torch_jit_decorator()
def project_pointcloud_to_depth(
pointcloud: torch.Tensor,
output_image: torch.Tensor,
@@ -106,7 +108,7 @@ def project_pointcloud_to_depth(
return output_image
@torch.jit.script
@get_torch_jit_decorator()
def project_depth_using_rays(
depth_image: torch.Tensor, rays: torch.Tensor, filter_origin: bool = False
):

View File

@@ -11,8 +11,10 @@
# Third Party
import torch
# from curobo.util.torch_utils import get_torch_jit_decorator
# @torch.jit.script
# @get_torch_jit_decorator()
def lookup_distance(pt, dist_matrix_flat, num_voxels):
# flatten:
ind_pt = (
@@ -22,7 +24,7 @@ def lookup_distance(pt, dist_matrix_flat, num_voxels):
return dist
# @torch.jit.script
# @get_torch_jit_decorator()
def compute_sdf_gradient(pt, dist_matrix_flat, num_voxels, dist):
grad_l = []
for i in range(3): # x,y,z

View File

@@ -30,6 +30,11 @@ def create_collision_checker(config: WorldCollisionConfig):
from curobo.geom.sdf.world_mesh import WorldMeshCollision
return WorldMeshCollision(config)
elif config.checker_type == CollisionCheckerType.VOXEL:
# CuRobo
from curobo.geom.sdf.world_voxel import WorldVoxelCollision
return WorldVoxelCollision(config)
else:
log_error("Not implemented", exc_info=True)
log_error("Unknown Collision Checker type: " + config.checker_type, exc_info=True)
raise NotImplementedError

View File

@@ -16,284 +16,6 @@ import warp as wp
wp.set_module_options({"fast_math": False})
# create warp kernels:
@wp.kernel
def get_swept_closest_pt(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
speed_dt: wp.array(dtype=wp.float32),
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.float32,
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
sweep_steps: wp.uint8,
enable_speed_metric: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = int(0)
tid = wp.tid()
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
# read horizon
eta = float(0.0) # 5cm buffer
dt = float(1.0)
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
n_mesh = int(0)
# $wp.printf("%d, %d, %d, %d \n", tid, b_idx, h_idx, sph_idx)
# read sphere
sphere_0_distance = float(0.0)
sphere_2_distance = float(0.0)
sphere_0 = wp.vec3(0.0)
sphere_2 = wp.vec3(0.0)
sphere_int = wp.vec3(0.0)
sphere_temp = wp.vec3(0.0)
k0 = float(0.0)
face_index = int(0)
face_u = float(0.0)
face_v = float(0.0)
sign = float(0.0)
dist = float(0.0)
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
in_rad = in_sphere[3]
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
eta = activation_distance[0]
dt = speed_dt[0]
in_rad += eta
max_dist_buffer = float(0.0)
max_dist_buffer = max_dist
if in_rad > max_dist_buffer:
max_dist_buffer += in_rad
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
# read in sphere 0:
# read in sphere 0:
if h_idx > 0:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx - 1) * nspheres) + sph_idx]
sphere_0 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_0_distance = wp.length(sphere_0 - in_pt) / 2.0
if h_idx < horizon - 1:
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx + 1) * nspheres) + sph_idx]
sphere_2 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
sphere_2_distance = wp.length(sphere_2 - in_pt) / 2.0
# read in sphere 2:
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
i = int(0)
dis_length = float(0.0)
jump_distance = float(0.0)
mid_distance = float(0.0)
n_mesh = n_env_mesh[0]
obj_position = wp.vec3()
while i < n_mesh:
if mesh_enable[i] == uint_one:
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
obj_w_pose_t = wp.transform_inverse(obj_w_pose)
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
local_pt = wp.transform_point(obj_w_pose, in_pt)
if wp.mesh_query_point(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - local_pt)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - local_pt) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
else:
dist = -1.0 * dist
else:
dist = in_rad
dist = max(dist - in_rad, in_rad)
mid_distance = dist
# transform sphere -1
if h_idx > 0 and mid_distance < sphere_0_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_0)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_0_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + ((1.0 - k0) * sphere_temp)
if wp.mesh_query_point(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - sphere_int)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - sphere_int) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
dist = max(dist - in_rad, in_rad)
jump_distance += dist
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_0_distance:
j = int(sweep_steps)
# transform sphere -1
if h_idx < horizon - 1 and mid_distance < sphere_2_distance:
jump_distance = mid_distance
j = int(0)
sphere_temp = wp.transform_point(obj_w_pose, sphere_2)
while j < sweep_steps:
k0 = (
1.0 - 0.5 * jump_distance / sphere_2_distance
) # dist could be greater than sphere_0_distance here?
sphere_int = k0 * local_pt + (1.0 - k0) * sphere_temp
if wp.mesh_query_point(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - sphere_int)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - sphere_int) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
dist = max(dist - in_rad, in_rad)
jump_distance += dist
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
else:
jump_distance += max_dist_buffer
j += 1
if jump_distance >= sphere_2_distance:
j = int(sweep_steps)
i += 1
# return
if closest_distance == 0:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
if enable_speed_metric == 1 and (h_idx > 0 and h_idx < horizon - 1):
# calculate sphere velocity and acceleration:
norm_vel_vec = wp.vec3(0.0)
sph_acc_vec = wp.vec3(0.0)
sph_vel = wp.float(0.0)
# use central difference
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
# norm_vel_vec = -1.0 * norm_vel_vec
# sph_acc_vec = -1.0 * sph_acc_vec
sph_vel = wp.length(norm_vel_vec)
norm_vel_vec = norm_vel_vec / sph_vel
orth_proj = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) - wp.outer(
norm_vel_vec, norm_vel_vec
)
curvature_vec = orth_proj * (sph_acc_vec / (sph_vel * sph_vel))
closest_point = sph_vel * ((orth_proj * closest_point) - closest_distance * curvature_vec)
closest_distance = sph_vel * closest_distance
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
if closest_distance > 0.0:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@wp.kernel
def get_swept_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
@@ -307,7 +29,7 @@ def get_swept_closest_pt_batch_env(
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.float32,
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
@@ -316,6 +38,7 @@ def get_swept_closest_pt_batch_env(
sweep_steps: wp.uint8,
enable_speed_metric: wp.uint8,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
@@ -357,6 +80,7 @@ def get_swept_closest_pt_batch_env(
sign = float(0.0)
dist = float(0.0)
dist_metric = float(0.0)
euclidean_distance = float(0.0)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
@@ -374,7 +98,7 @@ def get_swept_closest_pt_batch_env(
eta = activation_distance[0]
in_rad += eta
max_dist_buffer = float(0.0)
max_dist_buffer = max_dist
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
@@ -396,7 +120,8 @@ def get_swept_closest_pt_batch_env(
dis_length = float(0.0)
jump_distance = float(0.0)
mid_distance = float(0.0)
env_idx = env_query_idx[b_idx]
if use_batch_env:
env_idx = env_query_idx[b_idx]
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
@@ -423,26 +148,33 @@ def get_swept_closest_pt_batch_env(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - local_pt)
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - local_pt) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
else:
dist = -1.0 * dist
euclidean_distance = dist
else:
dist = max_dist_buffer
dist = max(dist - in_rad, in_rad)
euclidean_distance = dist
dist = max(euclidean_distance - in_rad, in_rad)
mid_distance = dist
mid_distance = euclidean_distance
# transform sphere -1
if h_idx > 0 and mid_distance < sphere_0_distance:
jump_distance = mid_distance
@@ -457,24 +189,31 @@ def get_swept_closest_pt_batch_env(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - sphere_int)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - sphere_int) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
euclidean_distance = dist
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(dist - in_rad, in_rad)
jump_distance += dist
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += euclidean_distance
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
euclidean_distance = dist
jump_distance += euclidean_distance
else:
jump_distance += max_dist_buffer
j += 1
@@ -495,24 +234,30 @@ def get_swept_closest_pt_batch_env(
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - sphere_int)
delta = cl_pt - sphere_int
dis_length = wp.length(delta)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - sphere_int) / dis_length
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
euclidean_distance = dist
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
# cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
cl_pt = (1.0 / eta) * dist * cl_pt
closest_distance += dist_metric
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
closest_point += grad_vec
dist = max(dist - in_rad, in_rad)
dist = max(euclidean_distance - in_rad, in_rad)
jump_distance += dist
else:
dist = max(-dist - in_rad, in_rad)
jump_distance += dist
else:
jump_distance += max_dist_buffer
@@ -542,179 +287,54 @@ def get_swept_closest_pt_batch_env(
# use central difference
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
# norm_vel_vec = -1.0 * norm_vel_vec
# sph_acc_vec = -1.0 * sph_acc_vec
sph_vel = wp.length(norm_vel_vec)
if sph_vel > 1e-3:
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
norm_vel_vec = norm_vel_vec / sph_vel
norm_vel_vec = norm_vel_vec * (1.0 / sph_vel)
orth_proj = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) - wp.outer(
norm_vel_vec, norm_vel_vec
)
curvature_vec = sph_acc_vec / (sph_vel * sph_vel)
curvature_vec = orth_proj * (sph_acc_vec / (sph_vel * sph_vel))
orth_proj = wp.mat33(0.0)
for i in range(3):
for j in range(3):
orth_proj[i, j] = -1.0 * norm_vel_vec[i] * norm_vel_vec[j]
closest_point = sph_vel * ((orth_proj * closest_point) - closest_distance * curvature_vec)
orth_proj[0, 0] = orth_proj[0, 0] + 1.0
orth_proj[1, 1] = orth_proj[1, 1] + 1.0
orth_proj[2, 2] = orth_proj[2, 2] + 1.0
closest_distance = sph_vel * closest_distance
orth_curv = wp.vec3(
0.0, 0.0, 0.0
) # closest_distance * (orth_proj @ curvature_vec) #wp.matmul(orth_proj, curvature_vec)
orth_pt = wp.vec3(0.0, 0.0, 0.0) # orth_proj @ closest_point
for i in range(3):
orth_pt[i] = (
orth_proj[i, 0] * closest_point[0]
+ orth_proj[i, 1] * closest_point[1]
+ orth_proj[i, 2] * closest_point[2]
)
orth_curv[i] = closest_distance * (
orth_proj[i, 0] * curvature_vec[0]
+ orth_proj[i, 1] * curvature_vec[1]
+ orth_proj[i, 2] * curvature_vec[2]
)
closest_point = sph_vel * (orth_pt - orth_curv)
closest_distance = sph_vel * closest_distance
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
if closest_distance > 0.0:
closest_distance = weight[0]
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@wp.kernel
def get_closest_pt(
pt: wp.array(dtype=wp.vec4),
distance: wp.array(dtype=wp.float32), # this stores the output cost
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
sparsity_idx: wp.array(dtype=wp.uint8),
weight: wp.array(dtype=wp.float32),
activation_distance: wp.array(dtype=wp.float32), # eta threshold
mesh: wp.array(dtype=wp.uint64),
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.float32,
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
):
# we launch nspheres kernels
# compute gradient here and return
# distance is negative outside and positive inside
tid = wp.tid()
n_mesh = int(0)
b_idx = int(0)
h_idx = int(0)
sph_idx = int(0)
# env_idx = int(0)
b_idx = tid / (horizon * nspheres)
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
return
face_index = int(0)
face_u = float(0.0)
face_v = float(0.0)
sign = float(0.0)
dist = float(0.0)
grad_vec = wp.vec3(0.0)
eta = float(0.05)
dist_metric = float(0.0)
cl_pt = wp.vec3(0.0)
local_pt = wp.vec3(0.0)
in_sphere = pt[tid]
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
in_rad = in_sphere[3]
uint_zero = wp.uint8(0)
uint_one = wp.uint8(1)
if in_rad < 0.0:
distance[tid] = 0.0
if write_grad == 1 and sparsity_idx[tid] == uint_one:
sparsity_idx[tid] = uint_zero
closest_pt[tid * 4] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
return
eta = activation_distance[0]
in_rad += eta
max_dist_buffer = float(0.0)
max_dist_buffer = max_dist
if in_rad > max_dist_buffer:
max_dist_buffer += in_rad
# TODO: read vec4 and use first 3 for sphere position and last one for radius
# in_pt = pt[tid]
closest_distance = float(0.0)
closest_point = wp.vec3(0.0)
i = int(0)
dis_length = float(0.0)
# read env index:
# env_idx = env_query_idx[b_idx]
# read number of boxes in current environment:
# get start index
i = int(0)
n_mesh = n_env_mesh[0]
obj_position = wp.vec3()
# mesh_idx = wp.uint64(0)
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
# mesh_pt = T_inverse @ w_pt
# read object pose:
obj_position[0] = mesh_pose[i * 8 + 0]
obj_position[1] = mesh_pose[i * 8 + 1]
obj_position[2] = mesh_pose[i * 8 + 2]
obj_quat = wp.quaternion(
mesh_pose[i * 8 + 4],
mesh_pose[i * 8 + 5],
mesh_pose[i * 8 + 6],
mesh_pose[i * 8 + 3],
)
obj_w_pose = wp.transform(obj_position, obj_quat)
local_pt = wp.transform_point(obj_w_pose, in_pt)
# mesh_idx = mesh[i]
if wp.mesh_query_point(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - local_pt)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - local_pt) / dis_length
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
i += 1
if closest_distance == 0:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
distance[tid] = 0.0
if write_grad == 1:
closest_pt[tid * 4 + 0] = 0.0
closest_pt[tid * 4 + 1] = 0.0
closest_pt[tid * 4 + 2] = 0.0
else:
distance[tid] = weight[0] * closest_distance
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
if closest_distance > 0.0:
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@wp.kernel
def get_closest_pt_batch_env(
pt: wp.array(dtype=wp.vec4),
@@ -727,13 +347,15 @@ def get_closest_pt_batch_env(
mesh_pose: wp.array(dtype=wp.float32),
mesh_enable: wp.array(dtype=wp.uint8),
n_env_mesh: wp.array(dtype=wp.int32),
max_dist: wp.float32,
max_dist: wp.array(dtype=wp.float32),
write_grad: wp.uint8,
batch_size: wp.int32,
horizon: wp.int32,
nspheres: wp.int32,
max_nmesh: wp.int32,
env_query_idx: wp.array(dtype=wp.int32),
use_batch_env: wp.uint8,
compute_esdf: wp.uint8,
):
# we launch nspheres kernels
# compute gradient here and return
@@ -779,8 +401,9 @@ def get_closest_pt_batch_env(
return
eta = activation_distance[0]
in_rad += eta
max_dist_buffer = max_dist
if compute_esdf != 1:
in_rad += eta
max_dist_buffer = max_dist[0]
if (in_rad) > max_dist_buffer:
max_dist_buffer += in_rad
@@ -791,7 +414,9 @@ def get_closest_pt_batch_env(
dis_length = float(0.0)
# read env index:
env_idx = env_query_idx[b_idx]
if use_batch_env:
env_idx = env_query_idx[b_idx]
# read number of boxes in current environment:
# get start index
@@ -799,7 +424,9 @@ def get_closest_pt_batch_env(
i = max_nmesh * env_idx
n_mesh = i + n_env_mesh[env_idx]
obj_position = wp.vec3()
max_dist_value = -1.0 * max_dist_buffer
if compute_esdf == 1:
closest_distance = max_dist_value
while i < n_mesh:
if mesh_enable[i] == uint_one:
# transform point to mesh frame:
@@ -822,21 +449,39 @@ def get_closest_pt_batch_env(
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
):
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
dis_length = wp.length(cl_pt - local_pt)
dist = (-1.0 * dis_length * sign) + in_rad
if dist > 0:
cl_pt = sign * (cl_pt - local_pt) / dis_length
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
grad_vec = (1.0 / eta) * dist * grad_vec
closest_distance += dist_metric
closest_point += grad_vec
delta = cl_pt - local_pt
dis_length = wp.length(delta)
dist = -1.0 * dis_length * sign
if compute_esdf == 1:
if dist > max_dist_value:
max_dist_value = dist
closest_distance = dist
if write_grad == 1:
cl_pt = sign * (delta) / dis_length
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_point = grad_vec
else:
dist = dist + in_rad
if dist > 0:
if dist == in_rad:
cl_pt = sign * (delta) / (dist)
else:
cl_pt = sign * (delta) / dis_length
if dist > eta:
dist_metric = dist - 0.5 * eta
elif dist <= eta:
dist_metric = (0.5 / eta) * (dist) * dist
cl_pt = (1.0 / eta) * dist * cl_pt
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
closest_distance += dist_metric
closest_point += grad_vec
i += 1
if closest_distance == 0:
if closest_distance == 0 and compute_esdf != 1:
if sparsity_idx[tid] == uint_zero:
return
sparsity_idx[tid] = uint_zero
@@ -850,8 +495,7 @@ def get_closest_pt_batch_env(
sparsity_idx[tid] = uint_one
if write_grad == 1:
# compute gradient:
if closest_distance > 0.0:
closest_distance = weight[0]
closest_distance = weight[0]
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
@@ -871,62 +515,43 @@ class SdfMeshWarpPy(torch.autograd.Function):
mesh_pose_inverse,
mesh_enable,
n_env_mesh,
max_dist=0.05,
max_dist,
env_query_idx=None,
return_loss=False,
compute_esdf=False,
):
b, h, n, _ = query_spheres.shape
use_batch_env = True
if env_query_idx is None:
# launch
wp.launch(
kernel=get_closest_pt,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
max_dist,
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
],
stream=wp.stream_from_torch(query_spheres.device),
)
else:
wp.launch(
kernel=get_closest_pt_batch_env,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
max_dist,
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
],
stream=wp.stream_from_torch(query_spheres.device),
)
use_batch_env = False
env_query_idx = n_env_mesh
wp.launch(
kernel=get_closest_pt_batch_env,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
wp.from_torch(max_dist, dtype=wp.float32),
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
use_batch_env,
compute_esdf,
],
stream=wp.stream_from_torch(query_spheres.device),
)
ctx.return_loss = return_loss
ctx.save_for_backward(out_grad)
return out_cost
@@ -939,7 +564,22 @@ class SdfMeshWarpPy(torch.autograd.Function):
grad_sph = r
if ctx.return_loss:
grad_sph = r * grad_output.unsqueeze(-1)
return grad_sph, None, None, None, None, None, None, None, None, None, None, None, None
return (
grad_sph,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class SweptSdfMeshWarpPy(torch.autograd.Function):
@@ -957,69 +597,46 @@ class SweptSdfMeshWarpPy(torch.autograd.Function):
mesh_pose_inverse,
mesh_enable,
n_env_mesh,
max_dist,
sweep_steps=1,
enable_speed_metric=False,
max_dist=0.05,
env_query_idx=None,
return_loss=False,
):
b, h, n, _ = query_spheres.shape
use_batch_env = True
if env_query_idx is None:
wp.launch(
kernel=get_swept_closest_pt,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(speed_dt),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
max_dist,
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
sweep_steps,
enable_speed_metric,
],
stream=wp.stream_from_torch(query_spheres.device),
)
else:
wp.launch(
kernel=get_swept_closest_pt_batch_env,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(speed_dt),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
max_dist,
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
sweep_steps,
enable_speed_metric,
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
],
stream=wp.stream_from_torch(query_spheres.device),
)
use_batch_env = False
env_query_idx = n_env_mesh
wp.launch(
kernel=get_swept_closest_pt_batch_env,
dim=b * h * n,
inputs=[
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
wp.from_torch(out_cost.view(-1)),
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
wp.from_torch(weight),
wp.from_torch(activation_distance),
wp.from_torch(speed_dt),
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
wp.from_torch(max_dist, dtype=wp.float32),
query_spheres.requires_grad,
b,
h,
n,
mesh_idx.shape[1],
sweep_steps,
enable_speed_metric,
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
use_batch_env,
],
stream=wp.stream_from_torch(query_spheres.device),
)
ctx.return_loss = return_loss
ctx.save_for_backward(out_grad)
return out_cost

View File

@@ -19,7 +19,7 @@ import torch
# CuRobo
from curobo.curobolib.geom import SdfSphereOBB, SdfSweptSphereOBB
from curobo.geom.types import Cuboid, Mesh, Obstacle, WorldConfig, batch_tensor_cube
from curobo.geom.types import Cuboid, Mesh, Obstacle, VoxelGrid, WorldConfig, batch_tensor_cube
from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
from curobo.util.logger import log_error, log_info, log_warn
@@ -39,10 +39,14 @@ class CollisionBuffer:
def initialize_from_shape(cls, shape: torch.Size, tensor_args: TensorDeviceType):
batch, horizon, n_spheres, _ = shape
distance_buffer = torch.zeros(
(batch, horizon, n_spheres), device=tensor_args.device, dtype=tensor_args.dtype
(batch, horizon, n_spheres),
device=tensor_args.device,
dtype=tensor_args.collision_distance_dtype,
)
grad_distance_buffer = torch.zeros(
(batch, horizon, n_spheres, 4), device=tensor_args.device, dtype=tensor_args.dtype
(batch, horizon, n_spheres, 4),
device=tensor_args.device,
dtype=tensor_args.collision_gradient_dtype,
)
sparsity_idx = torch.zeros(
(batch, horizon, n_spheres),
@@ -54,10 +58,14 @@ class CollisionBuffer:
def _update_from_shape(self, shape: torch.Size, tensor_args: TensorDeviceType):
batch, horizon, n_spheres, _ = shape
self.distance_buffer = torch.zeros(
(batch, horizon, n_spheres), device=tensor_args.device, dtype=tensor_args.dtype
(batch, horizon, n_spheres),
device=tensor_args.device,
dtype=tensor_args.collision_distance_dtype,
)
self.grad_distance_buffer = torch.zeros(
(batch, horizon, n_spheres, 4), device=tensor_args.device, dtype=tensor_args.dtype
(batch, horizon, n_spheres, 4),
device=tensor_args.device,
dtype=tensor_args.collision_gradient_dtype,
)
self.sparsity_index_buffer = torch.zeros(
(batch, horizon, n_spheres),
@@ -100,6 +108,7 @@ class CollisionQueryBuffer:
primitive_collision_buffer: Optional[CollisionBuffer] = None
mesh_collision_buffer: Optional[CollisionBuffer] = None
blox_collision_buffer: Optional[CollisionBuffer] = None
voxel_collision_buffer: Optional[CollisionBuffer] = None
shape: Optional[torch.Size] = None
def __post_init__(self):
@@ -110,6 +119,8 @@ class CollisionQueryBuffer:
self.shape = self.mesh_collision_buffer.shape
elif self.blox_collision_buffer is not None:
self.shape = self.blox_collision_buffer.shape
elif self.voxel_collision_buffer is not None:
self.shape = self.voxel_collision_buffer.shape
def __mul__(self, scalar: float):
if self.primitive_collision_buffer is not None:
@@ -118,17 +129,27 @@ class CollisionQueryBuffer:
self.mesh_collision_buffer = self.mesh_collision_buffer * scalar
if self.blox_collision_buffer is not None:
self.blox_collision_buffer = self.blox_collision_buffer * scalar
if self.voxel_collision_buffer is not None:
self.voxel_collision_buffer = self.voxel_collision_buffer * scalar
return self
def clone(self):
prim_buffer = mesh_buffer = blox_buffer = None
prim_buffer = mesh_buffer = blox_buffer = voxel_buffer = None
if self.primitive_collision_buffer is not None:
prim_buffer = self.primitive_collision_buffer.clone()
if self.mesh_collision_buffer is not None:
mesh_buffer = self.mesh_collision_buffer.clone()
if self.blox_collision_buffer is not None:
blox_buffer = self.blox_collision_buffer.clone()
return CollisionQueryBuffer(prim_buffer, mesh_buffer, blox_buffer, self.shape)
if self.voxel_collision_buffer is not None:
voxel_buffer = self.voxel_collision_buffer.clone()
return CollisionQueryBuffer(
prim_buffer,
mesh_buffer,
blox_buffer,
voxel_collision_buffer=voxel_buffer,
shape=self.shape,
)
@classmethod
def initialize_from_shape(
@@ -137,14 +158,18 @@ class CollisionQueryBuffer:
tensor_args: TensorDeviceType,
collision_types: Dict[str, bool],
):
primitive_buffer = mesh_buffer = blox_buffer = None
primitive_buffer = mesh_buffer = blox_buffer = voxel_buffer = None
if "primitive" in collision_types and collision_types["primitive"]:
primitive_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
if "mesh" in collision_types and collision_types["mesh"]:
mesh_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
if "blox" in collision_types and collision_types["blox"]:
blox_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
return CollisionQueryBuffer(primitive_buffer, mesh_buffer, blox_buffer)
if "voxel" in collision_types and collision_types["voxel"]:
voxel_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
return CollisionQueryBuffer(
primitive_buffer, mesh_buffer, blox_buffer, voxel_collision_buffer=voxel_buffer
)
def create_from_shape(
self,
@@ -160,8 +185,9 @@ class CollisionQueryBuffer:
self.mesh_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
if "blox" in collision_types and collision_types["blox"]:
self.blox_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
if "voxel" in collision_types and collision_types["voxel"]:
self.voxel_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
self.shape = shape
# return self
def update_buffer_shape(
self,
@@ -169,12 +195,10 @@ class CollisionQueryBuffer:
tensor_args: TensorDeviceType,
collision_types: Optional[Dict[str, bool]],
):
# print(shape, self.shape)
# update buffers:
assert len(shape) == 4 # shape is: batch, horizon, n_spheres, 4
if self.shape is None: # buffers not initialized:
self.create_from_shape(shape, tensor_args, collision_types)
# print("Creating new memory", self.shape)
else:
# update buffers if shape doesn't match:
# TODO: allow for dynamic change of collision_types
@@ -185,6 +209,8 @@ class CollisionQueryBuffer:
self.mesh_collision_buffer.update_buffer_shape(shape, tensor_args)
if self.blox_collision_buffer is not None:
self.blox_collision_buffer.update_buffer_shape(shape, tensor_args)
if self.voxel_collision_buffer is not None:
self.voxel_collision_buffer.update_buffer_shape(shape, tensor_args)
self.shape = shape
def get_gradient_buffer(
@@ -208,6 +234,12 @@ class CollisionQueryBuffer:
current_buffer = blox_buffer.clone()
else:
current_buffer += blox_buffer
if self.voxel_collision_buffer is not None:
voxel_buffer = self.voxel_collision_buffer.grad_distance_buffer
if current_buffer is None:
current_buffer = voxel_buffer.clone()
else:
current_buffer += voxel_buffer
return current_buffer
@@ -221,6 +253,7 @@ class CollisionCheckerType(Enum):
PRIMITIVE = "PRIMITIVE"
BLOX = "BLOX"
MESH = "MESH"
VOXEL = "VOXEL"
@dataclass
@@ -230,11 +263,13 @@ class WorldCollisionConfig:
cache: Optional[Dict[Obstacle, int]] = None
n_envs: int = 1
checker_type: CollisionCheckerType = CollisionCheckerType.PRIMITIVE
max_distance: float = 0.01
max_distance: Union[torch.Tensor, float] = 0.01
def __post_init__(self):
if self.world_model is not None and isinstance(self.world_model, list):
self.n_envs = len(self.world_model)
if isinstance(self.max_distance, float):
self.max_distance = self.tensor_args.to_device([self.max_distance])
@staticmethod
def load_from_dict(
@@ -261,6 +296,8 @@ class WorldCollision(WorldCollisionConfig):
if config is not None:
WorldCollisionConfig.__init__(self, **vars(config))
self.collision_types = {} # Use this dictionary to store collision types
self._cache_voxelization = None
self._cache_voxelization_collision_buffer = None
def load_collision_model(self, world_model: WorldConfig):
raise NotImplementedError
@@ -273,6 +310,8 @@ class WorldCollision(WorldCollisionConfig):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
"""
Computes the signed distance via analytic function
@@ -310,6 +349,7 @@ class WorldCollision(WorldCollisionConfig):
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
):
raise NotImplementedError
@@ -338,6 +378,118 @@ class WorldCollision(WorldCollisionConfig):
):
raise NotImplementedError
def get_voxels_in_bounding_box(
    self,
    cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
    voxel_size: float = 0.02,
) -> Union[List[Cuboid], torch.Tensor]:
    """Return the xyzr rows of voxels inside ``cuboid`` whose occupancy value is > 0.

    Occupancy comes from :meth:`get_occupancy_in_bounding_box`; a voxel is kept
    when its feature value exceeds 0.0.

    NOTE(review): the ``Cuboid`` default argument is a shared mutable instance —
    confirm callers never mutate it.
    """
    new_grid = self.get_occupancy_in_bounding_box(cuboid, voxel_size)
    occupied = new_grid.get_occupied_voxels(0.0)
    return occupied
def clear_voxelization_cache(self):
    """Drop the cached voxelization grid; the next bounding-box query rebuilds
    it (and its collision buffer) via :meth:`update_cache_voxelization`."""
    self._cache_voxelization = None
def update_cache_voxelization(self, new_grid: VoxelGrid):
    """(Re)build the cached voxelization when the requested grid geometry changed.

    The cache is rebuilt when no grid is cached yet, or when ``new_grid`` has a
    different voxel size or dims than the cached grid. The matching collision
    query buffer is sized to one query sphere per voxel.

    NOTE(review): the grid pose is not part of the comparison, so a query with
    the same dims/voxel_size but a different pose reuses the old voxel
    centers — confirm this is intended.

    Args:
        new_grid: grid describing the requested voxelization region.
    """
    if (
        self._cache_voxelization is None
        or self._cache_voxelization.voxel_size != new_grid.voxel_size
        or self._cache_voxelization.dims != new_grid.dims
    ):
        self._cache_voxelization = new_grid
        # Voxel centers (+ radius) in world frame, one xyzr row per voxel.
        self._cache_voxelization.xyzr_tensor = self._cache_voxelization.create_xyzr_tensor(
            transform_to_origin=True, tensor_args=self.tensor_args
        )
        self._cache_voxelization_collision_buffer = CollisionQueryBuffer()
        xyzr = self._cache_voxelization.xyzr_tensor.view(-1, 1, 1, 4)
        self._cache_voxelization_collision_buffer.update_buffer_shape(
            xyzr.shape,
            self.tensor_args,
            self.collision_types,
        )
def get_occupancy_in_bounding_box(
    self,
    cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
    voxel_size: float = 0.02,
) -> VoxelGrid:
    """Voxelize ``cuboid`` and store per-voxel collision values as grid features.

    Each voxel center is queried through :meth:`get_sphere_collision` with zero
    activation distance; the returned grid's ``feature_tensor`` holds the raw
    query output (values > 0 are treated as occupied by callers such as
    :meth:`get_voxels_in_bounding_box`).

    Args:
        cuboid: bounding region (pose + dims) to voxelize.
        voxel_size: edge length of each voxel.
    """
    new_grid = VoxelGrid(
        name=cuboid.name, dims=cuboid.dims, pose=cuboid.pose, voxel_size=voxel_size
    )
    # Reuse (or rebuild) the cached voxel centers and query buffer.
    self.update_cache_voxelization(new_grid)
    xyzr = self._cache_voxelization.xyzr_tensor
    xyzr = xyzr.view(-1, 1, 1, 4)
    weight = self.tensor_args.to_device([1.0])
    act_distance = self.tensor_args.to_device([0.0])
    d_sph = self.get_sphere_collision(
        xyzr,
        self._cache_voxelization_collision_buffer,
        weight,
        act_distance,
    )
    d_sph = d_sph.reshape(-1)
    new_grid.xyzr_tensor = self._cache_voxelization.xyzr_tensor
    new_grid.feature_tensor = d_sph
    return new_grid
def get_esdf_in_bounding_box(
    self,
    cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
    voxel_size: float = 0.02,
    dtype=torch.float32,
) -> VoxelGrid:
    """Compute a Euclidean signed-distance field over ``cuboid``.

    Queries :meth:`get_sphere_distance` with ``compute_esdf=True`` on a cached
    voxelization of the bounding box and writes the per-voxel signed distance
    into the cached grid's feature tensor.

    Args:
        cuboid: bounding region (pose + dims) to voxelize.
        voxel_size: edge length of each voxel.
        dtype: feature dtype requested for the grid.

    Returns:
        The cached :class:`VoxelGrid` with ``feature_tensor`` set to the ESDF
        values. NOTE: this is shared state — callers should copy it if it must
        survive the next bounding-box query.
    """
    new_grid = VoxelGrid(
        name=cuboid.name,
        dims=cuboid.dims,
        pose=cuboid.pose,
        voxel_size=voxel_size,
        feature_dtype=dtype,
    )
    self.update_cache_voxelization(new_grid)
    # One query sphere per voxel center. (Removed an unused `voxel_shape`
    # local that the original computed and never read.)
    xyzr = self._cache_voxelization.xyzr_tensor.view(-1, 1, 1, 4)
    weight = self.tensor_args.to_device([1.0])
    d_sph = self.get_sphere_distance(
        xyzr,
        self._cache_voxelization_collision_buffer,
        weight,
        self.max_distance,
        sum_collisions=False,
        compute_esdf=True,
    )
    voxel_grid = self._cache_voxelization
    voxel_grid.feature_tensor = d_sph.reshape(-1)
    return voxel_grid
def get_mesh_in_bounding_box(
    self,
    cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
    voxel_size: float = 0.02,
) -> Mesh:
    """Approximate the world inside ``cuboid`` as a single mesh.

    Occupied voxel centers are meshed as a point cloud; the pitch is slightly
    larger than the voxel size so adjacent voxels fuse into one surface.

    Args:
        cuboid: bounding region (pose + dims) to voxelize.
        voxel_size: edge length of each voxel.
    """
    voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size)
    mesh = Mesh.from_pointcloud(
        voxels[:, :3].detach().cpu().numpy(),
        pitch=voxel_size * 1.1,
    )
    return mesh
class WorldPrimitiveCollision(WorldCollision):
"""World Oriented Bounding Box representation object
@@ -354,6 +506,7 @@ class WorldPrimitiveCollision(WorldCollision):
self._env_n_obbs = None
self._env_obbs_names = None
self._init_cache()
if self.world_model is not None:
if isinstance(self.world_model, list):
self.load_batch_collision_model(self.world_model)
@@ -656,6 +809,8 @@ class WorldPrimitiveCollision(WorldCollision):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
raise ValueError("Primitive Collision has no obstacles")
@@ -673,6 +828,7 @@ class WorldPrimitiveCollision(WorldCollision):
collision_query_buffer.primitive_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self._cube_tensor_list[0],
self._cube_tensor_list[0],
self._cube_tensor_list[1],
@@ -687,6 +843,8 @@ class WorldPrimitiveCollision(WorldCollision):
True,
use_batch_env,
return_loss,
sum_collisions,
compute_esdf,
)
return dist
@@ -699,6 +857,7 @@ class WorldPrimitiveCollision(WorldCollision):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
**kwargs,
):
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
raise ValueError("Primitive Collision has no obstacles")
@@ -717,6 +876,7 @@ class WorldPrimitiveCollision(WorldCollision):
collision_query_buffer.primitive_collision_buffer.sparsity_index_buffer,
weight,
activation_distance,
self.max_distance,
self._cube_tensor_list[0],
self._cube_tensor_list[0],
self._cube_tensor_list[1],
@@ -728,7 +888,7 @@ class WorldPrimitiveCollision(WorldCollision):
h,
n,
query_sphere.requires_grad,
False,
True,
use_batch_env,
return_loss,
)
@@ -745,6 +905,7 @@ class WorldPrimitiveCollision(WorldCollision):
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss=False,
sum_collisions: bool = True,
):
"""
Computes the signed distance via analytic function
@@ -784,6 +945,7 @@ class WorldPrimitiveCollision(WorldCollision):
True,
use_batch_env,
return_loss,
sum_collisions,
)
return dist
@@ -836,7 +998,7 @@ class WorldPrimitiveCollision(WorldCollision):
sweep_steps,
enable_speed_metric,
query_sphere.requires_grad,
False,
True,
use_batch_env,
return_loss,
)
@@ -845,70 +1007,5 @@ class WorldPrimitiveCollision(WorldCollision):
def clear_cache(self):
if self._cube_tensor_list is not None:
self._cube_tensor_list[2][:] = 1
self._cube_tensor_list[2][:] = 0
self._env_n_obbs[:] = 0
def get_voxels_in_bounding_box(
self,
cuboid: Cuboid,
voxel_size: float = 0.02,
) -> Union[List[Cuboid], torch.Tensor]:
bounds = cuboid.dims
low = [-bounds[0], -bounds[1], -bounds[2]]
high = [bounds[0], bounds[1], bounds[2]]
trange = [h - l for l, h in zip(low, high)]
x = torch.linspace(
-bounds[0], bounds[0], int(trange[0] // voxel_size) + 1, device=self.tensor_args.device
)
y = torch.linspace(
-bounds[1], bounds[1], int(trange[1] // voxel_size) + 1, device=self.tensor_args.device
)
z = torch.linspace(
-bounds[2], bounds[2], int(trange[2] // voxel_size) + 1, device=self.tensor_args.device
)
w, l, h = x.shape[0], y.shape[0], z.shape[0]
xyz = (
torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
)
pose = Pose.from_list(cuboid.pose, tensor_args=self.tensor_args)
xyz = pose.transform_points(xyz.contiguous())
r = torch.zeros_like(xyz[:, 0:1]) + voxel_size
xyzr = torch.cat([xyz, r], dim=1)
xyzr = xyzr.reshape(-1, 1, 1, 4)
collision_buffer = CollisionQueryBuffer()
collision_buffer.update_buffer_shape(
xyzr.shape,
self.tensor_args,
self.collision_types,
)
weight = self.tensor_args.to_device([1.0])
act_distance = self.tensor_args.to_device([0.0])
d_sph = self.get_sphere_collision(
xyzr,
collision_buffer,
weight,
act_distance,
)
d_sph = d_sph.reshape(-1)
xyzr = xyzr.reshape(-1, 4)
# get occupied voxels:
occupied = xyzr[d_sph > 0.0]
return occupied
def get_mesh_in_bounding_box(
self,
cuboid: Cuboid,
voxel_size: float = 0.02,
) -> Mesh:
voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size)
# voxels = voxels.cpu().numpy()
# cuboids = [Cuboid(name="c_"+str(x), pose=[voxels[x,0],voxels[x,1],voxels[x,2], 1,0,0,0], dims=[voxel_size, voxel_size, voxel_size]) for x in range(voxels.shape[0])]
# mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0]
mesh = Mesh.from_pointcloud(
voxels[:, :3].detach().cpu().numpy(),
pitch=voxel_size * 1.1,
)
return mesh

View File

@@ -176,6 +176,8 @@ class WorldBloxCollision(WorldMeshCollision):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
if "blox" not in self.collision_types or not self.collision_types["blox"]:
return super().get_sphere_distance(
@@ -185,6 +187,8 @@ class WorldBloxCollision(WorldMeshCollision):
activation_distance,
env_query_idx,
return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
d = self._get_blox_sdf(
@@ -205,8 +209,13 @@ class WorldBloxCollision(WorldMeshCollision):
activation_distance,
env_query_idx,
return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
d = d + d_base
if compute_esdf:
d = torch.maximum(d, d_base)
else:
d = d + d_base
return d
@@ -262,6 +271,7 @@ class WorldBloxCollision(WorldMeshCollision):
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
):
if "blox" not in self.collision_types or not self.collision_types["blox"]:
return super().get_swept_sphere_distance(
@@ -274,6 +284,7 @@ class WorldBloxCollision(WorldMeshCollision):
enable_speed_metric,
env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
d = self._get_blox_swept_sdf(
@@ -301,6 +312,7 @@ class WorldBloxCollision(WorldMeshCollision):
enable_speed_metric,
env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
d = d + d_base

View File

@@ -89,6 +89,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
self._mesh_tensor_list[0][env_idx, :max_nmesh] = w_mid
self._mesh_tensor_list[1][env_idx, :max_nmesh, :7] = w_inv_pose
self._mesh_tensor_list[2][env_idx, :max_nmesh] = 1
self._mesh_tensor_list[2][env_idx, max_nmesh:] = 0
self._env_mesh_names[env_idx][:max_nmesh] = name_list
self._env_n_mesh[env_idx] = max_nmesh
@@ -355,6 +356,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance: torch.Tensor,
env_query_idx=None,
return_loss=False,
compute_esdf=False,
):
d = SdfMeshWarpPy.apply(
query_spheres,
@@ -370,6 +372,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
self.max_distance,
env_query_idx,
return_loss,
compute_esdf,
)
return d
@@ -397,9 +400,9 @@ class WorldMeshCollision(WorldPrimitiveCollision):
self._mesh_tensor_list[1],
self._mesh_tensor_list[2],
self._env_n_mesh,
self.max_distance,
sweep_steps,
enable_speed_metric,
self.max_distance,
env_query_idx,
return_loss,
)
@@ -413,6 +416,8 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance: torch.Tensor,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
compute_esdf: bool = False,
):
# TODO: if no mesh object exist, call primitive
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
@@ -423,6 +428,8 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance=activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
d = self._get_sdf(
@@ -432,6 +439,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance=activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
compute_esdf=compute_esdf,
)
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
@@ -443,8 +451,13 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance=activation_distance,
env_query_idx=env_query_idx,
return_loss=return_loss,
sum_collisions=sum_collisions,
compute_esdf=compute_esdf,
)
d_val = d.view(d_prim.shape) + d_prim
if compute_esdf:
d_val = torch.maximum(d.view(d_prim.shape), d_prim)
else:
d_val = d.view(d_prim.shape) + d_prim
return d_val
def get_sphere_collision(
@@ -455,6 +468,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
activation_distance: torch.Tensor,
env_query_idx=None,
return_loss=False,
**kwargs,
):
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
return super().get_sphere_collision(
@@ -501,6 +515,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
enable_speed_metric=False,
env_query_idx: Optional[torch.Tensor] = None,
return_loss: bool = False,
sum_collisions: bool = True,
):
# log_warn("Swept: Mesh + Primitive Collision Checking is experimental")
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
@@ -514,6 +529,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
speed_dt=speed_dt,
enable_speed_metric=enable_speed_metric,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
d = self._get_swept_sdf(
@@ -540,6 +556,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
speed_dt=speed_dt,
enable_speed_metric=enable_speed_metric,
return_loss=return_loss,
sum_collisions=sum_collisions,
)
d_val = d.view(d_prim.shape) + d_prim
@@ -602,4 +619,11 @@ class WorldMeshCollision(WorldPrimitiveCollision):
self._wp_mesh_cache = {}
if self._mesh_tensor_list is not None:
self._mesh_tensor_list[2][:] = 0
if self._env_n_mesh is not None:
self._env_n_mesh[:] = 0
if self._env_mesh_names is not None:
self._env_mesh_names = [
[None for _ in range(self.cache["mesh"])] for _ in range(self.n_envs)
]
super().clear_cache()

View File

@@ -0,0 +1,699 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
import math
from typing import Any, Dict, List, Optional
# Third Party
import numpy as np
import torch
# CuRobo
from curobo.curobolib.geom import SdfSphereVoxel, SdfSweptSphereVoxel
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollisionConfig
from curobo.geom.sdf.world_mesh import WorldMeshCollision
from curobo.geom.types import VoxelGrid, WorldConfig
from curobo.types.math import Pose
from curobo.util.logger import log_error, log_info, log_warn
class WorldVoxelCollision(WorldMeshCollision):
"""Voxel grid representation of World, with each voxel containing Euclidean Signed Distance."""
def __init__(self, config: WorldCollisionConfig):
    """Create a voxel-grid collision world from ``config``.

    Voxel caches start empty; they are allocated lazily by
    ``_create_voxel_cache`` (invoked through the base constructor chain).
    """
    self._voxel_tensor_list = None
    self._env_n_voxels = None
    self._env_voxel_names = None
    super().__init__(config)
def _init_cache(self):
    """Allocate the voxel tensor cache when the config requests one."""
    voxel_cache = None if self.cache is None else self.cache.get("voxel")
    # A value of None or 0 means "no voxel layers requested".
    if voxel_cache not in (None, 0):
        self._create_voxel_cache(voxel_cache)
def _create_voxel_cache(self, voxel_cache: Dict[str, Any]):
    """Allocate per-environment tensors for voxel-grid layers.

    Args:
        voxel_cache: dict with keys ``layers`` (grids per env), ``dims``
            (grid bounds), ``voxel_size`` (cell edge length) and
            ``feature_dtype`` (dtype of the per-voxel feature).
    """
    layer_count = voxel_cache["layers"]
    grid_dims = voxel_cache["dims"]
    cell_size = voxel_cache["voxel_size"]
    feature_dtype = voxel_cache["feature_dtype"]
    # Cells that fit inside the bounds along each axis, multiplied together.
    cells_per_grid = int(
        math.floor(grid_dims[0] / cell_size)
        * math.floor(grid_dims[1] / cell_size)
        * math.floor(grid_dims[2] / cell_size)
    )
    device = self.tensor_args.device
    dtype = self.tensor_args.dtype
    # [dims.x, dims.y, dims.z, voxel_size] per layer.
    voxel_params = torch.zeros((self.n_envs, layer_count, 4), dtype=dtype, device=device)
    # Pose vector per layer (position + quaternion + padding).
    voxel_pose = torch.zeros((self.n_envs, layer_count, 8), dtype=dtype, device=device)
    voxel_pose[..., 3] = 1.0  # identity quaternion (w component)
    # Per-layer enabled flag.
    voxel_enable = torch.zeros(
        (self.n_envs, layer_count), dtype=torch.uint8, device=device
    )
    self._env_n_voxels = torch.zeros((self.n_envs), device=device, dtype=torch.int32)
    # One feature value per voxel cell.
    voxel_features = torch.zeros(
        (self.n_envs, layer_count, cells_per_grid, 1),
        device=device,
        dtype=feature_dtype,
    )
    self._voxel_tensor_list = [voxel_params, voxel_pose, voxel_enable, voxel_features]
    self.collision_types["voxel"] = True
    self._env_voxel_names = [[None for _ in range(layer_count)] for _ in range(self.n_envs)]
def load_collision_model(
    self, world_model: WorldConfig, env_idx=0, fix_cache_reference: bool = False
):
    """Load voxel layers from ``world_model`` into environment ``env_idx``,
    then delegate remaining obstacle types to the mesh/primitive base class."""
    self._load_collision_model_in_cache(
        world_model,
        env_idx=env_idx,
        fix_cache_reference=fix_cache_reference,
    )
    return super().load_collision_model(
        world_model,
        env_idx=env_idx,
        fix_cache_reference=fix_cache_reference,
    )
def _load_collision_model_in_cache(
    self, world_config: WorldConfig, env_idx: int = 0, fix_cache_reference: bool = False
):
    """Copy voxel-grid obstacles from ``world_config`` into the tensor cache.

    Args:
        world_config: world description; only its ``voxel`` list is read here.
        env_idx: environment slot to load the voxel layers into.
        fix_cache_reference: when True, the existing cache must already be
            large enough; an error is logged instead of reallocating.
    """
    voxel_objs = world_config.voxel
    max_obs = len(voxel_objs)
    self.world_model = world_config
    if max_obs < 1:
        log_info("No Voxel objs")
        return
    # Grow (or create) the cache when it cannot hold all voxel layers.
    if self._voxel_tensor_list is None or self._voxel_tensor_list[0].shape[1] < max_obs:
        if not fix_cache_reference:
            log_info("Creating Voxel cache" + str(max_obs))
            # Cache geometry is taken from the first obstacle; presumably all
            # layers share dims/voxel_size/dtype — TODO confirm.
            self._create_voxel_cache(
                {
                    "layers": max_obs,
                    "dims": voxel_objs[0].dims,
                    "voxel_size": voxel_objs[0].voxel_size,
                    "feature_dtype": voxel_objs[0].feature_dtype,
                }
            )
        else:
            log_error("number of OBB is larger than collision cache, create larger cache.")
    # load as a batch:
    pose_batch = [c.pose for c in voxel_objs]
    dims_batch = [c.dims for c in voxel_objs]
    names_batch = [c.name for c in voxel_objs]
    size_batch = [c.voxel_size for c in voxel_objs]
    voxel_batch = self._batch_tensor_voxel(pose_batch, dims_batch, size_batch)
    self._voxel_tensor_list[0][env_idx, :max_obs, :] = voxel_batch[0]
    self._voxel_tensor_list[1][env_idx, :max_obs, :7] = voxel_batch[1]
    self._voxel_tensor_list[2][env_idx, :max_obs] = 1  # enabling obstacle
    self._voxel_tensor_list[2][env_idx, max_obs:] = 0  # disabling obstacle
    # copy voxel grid features:
    self._env_n_voxels[env_idx] = max_obs
    self._env_voxel_names[env_idx][:max_obs] = names_batch
    self.collision_types["voxel"] = True
def _batch_tensor_voxel(
    self, pose: List[List[float]], dims: List[float], voxel_size: List[float]
):
    """Pack per-layer voxel parameters and inverse poses into device tensors.

    Returns ``[params, inv_pose]`` where ``params`` rows are
    [dims.x, dims.y, dims.z, voxel_size] and ``inv_pose`` is the pose-vector
    form of the inverted world pose of each layer.
    """
    world_from_grid = Pose.from_batch_list(pose, tensor_args=self.tensor_args)
    grid_from_world = world_from_grid.inverse()
    device = self.tensor_args.device
    dtype = self.tensor_args.dtype
    dims_tensor = torch.as_tensor(np.array(dims), device=device, dtype=dtype)
    size_tensor = torch.as_tensor(
        np.array(voxel_size), device=device, dtype=dtype
    ).unsqueeze(-1)
    params_tensor = torch.cat([dims_tensor, size_tensor], dim=-1)
    return [params_tensor, grid_from_world.get_pose_vector()]
def load_batch_collision_model(self, world_config_list: List[WorldConfig]):
    """Load voxel grids for batched environments.

    NOTE(review): this path is explicitly flagged as not implemented — if
    ``log_error`` raises (its usual behavior in this codebase — confirm), the
    code below is unreachable. The body also reads ``i.cuboid`` and calls
    ``self._create_obb_cache``, which look copy-pasted from the OBB loader
    (presumably should be ``i.voxel`` and ``_create_voxel_cache``) — verify
    before enabling this path.

    Args:
        world_config_list: one WorldConfig per environment.
    """
    log_error("Not Implemented")
    # First find largest number of cuboid:
    c_len = []
    pose_batch = []
    dims_batch = []
    names_batch = []
    vsize_batch = []
    for i in world_config_list:
        c = i.cuboid  # NOTE(review): likely should be i.voxel — see docstring.
        if c is not None:
            c_len.append(len(c))
            pose_batch.extend([i.pose for i in c])
            dims_batch.extend([i.dims for i in c])
            names_batch.extend([i.name for i in c])
            vsize_batch.extend([i.voxel_size for i in c])
        else:
            c_len.append(0)
    max_obs = max(c_len)
    if max_obs < 1:
        log_warn("No obbs found")
        return
    # check if number of environments is same as config:
    reset_buffers = False
    if self._env_n_voxels is not None and len(world_config_list) != len(self._env_n_voxels):
        log_warn(
            "env_n_voxels is not same as world_config_list, reloading collision buffers (breaks CG)"
        )
        reset_buffers = True
        self.n_envs = len(world_config_list)
        self._env_n_voxels = torch.zeros(
            (self.n_envs), device=self.tensor_args.device, dtype=torch.int32
        )
    if self._voxel_tensor_list is not None and self._voxel_tensor_list[0].shape[1] < max_obs:
        log_warn(
            "number of obbs is greater than buffer, reloading collision buffers (breaks CG)"
        )
        reset_buffers = True
    # create cache if does not exist:
    if self._voxel_tensor_list is None or reset_buffers:
        log_info("Creating Obb cache" + str(max_obs))
        # NOTE(review): likely should be _create_voxel_cache — see docstring.
        self._create_obb_cache(max_obs)
    # load obstacles:
    ## load data into gpu:
    voxel_batch = self._batch_tensor_voxel(pose_batch, dims_batch, vsize_batch)
    c_start = 0
    for i in range(len(self._env_n_voxels)):
        if c_len[i] > 0:
            # load obb:
            self._voxel_tensor_list[0][i, : c_len[i], :] = voxel_batch[0][
                c_start : c_start + c_len[i]
            ]
            self._voxel_tensor_list[1][i, : c_len[i], :7] = voxel_batch[1][
                c_start : c_start + c_len[i]
            ]
            self._voxel_tensor_list[2][i, : c_len[i]] = 1
            self._env_voxel_names[i][: c_len[i]] = names_batch[c_start : c_start + c_len[i]]
            self._voxel_tensor_list[2][i, c_len[i] :] = 0
            c_start += c_len[i]
    self._env_n_voxels[:] = torch.as_tensor(
        c_len, dtype=torch.int32, device=self.tensor_args.device
    )
    self.collision_types["voxel"] = True
    return super().load_batch_collision_model(world_config_list)
def enable_obstacle(
    self,
    name: str,
    enable: bool = True,
    env_idx: int = 0,
):
    """Enable/disable obstacle ``name``: voxel layers are handled here, any
    other obstacle type is delegated to the base class."""
    voxel_names = self._env_voxel_names
    if voxel_names is not None and name in voxel_names[env_idx]:
        self.enable_voxel(enable, name, None, env_idx)
    else:
        return super().enable_obstacle(name, enable, env_idx)
def enable_voxel(
    self,
    enable: bool = True,
    name: Optional[str] = None,
    env_obj_idx: Optional[torch.Tensor] = None,
    env_idx: int = 0,
):
    """Enable or disable a voxel layer.

    Args:
        enable: True enables the layer, False disables it.
        name: layer name; used when ``env_obj_idx`` is None.
        env_obj_idx: direct index into the enable tensor; bypasses name lookup.
        env_idx: environment containing the layer.
    """
    if env_obj_idx is not None:
        self._voxel_tensor_list[2][env_obj_idx] = int(enable)  # enable == 1
    else:
        # find index of given name:
        obs_idx = self.get_voxel_idx(name, env_idx)
        self._voxel_tensor_list[2][env_idx, obs_idx] = int(enable)
def update_obstacle_pose(
    self,
    name: str,
    w_obj_pose: Pose,
    env_idx: int = 0,
):
    """Update the world pose of obstacle ``name``.

    Voxel layers are updated here; other obstacle types are delegated to the
    base class, matching the dispatch pattern of ``enable_obstacle``.

    Args:
        name: obstacle name.
        w_obj_pose: new pose of the obstacle in the world frame.
        env_idx: environment containing the obstacle.
    """
    if self._env_voxel_names is not None and name in self._env_voxel_names[env_idx]:
        self.update_voxel_pose(name=name, w_obj_pose=w_obj_pose, env_idx=env_idx)
    else:
        # Fix: previously this only logged an error, so mesh/primitive
        # obstacle poses could not be updated through this class; delegate
        # to the base class instead (consistent with enable_obstacle).
        return super().update_obstacle_pose(name, w_obj_pose, env_idx)
def update_voxel_data(self, new_voxel: VoxelGrid, env_idx: int = 0):
    """Overwrite features, dims, voxel size, and pose of an existing voxel
    layer (looked up by ``new_voxel.name``) and mark it enabled."""
    layer = self.get_voxel_idx(new_voxel.name, env_idx)
    params, poses, enabled, features = self._voxel_tensor_list
    flat_features = new_voxel.feature_tensor.view(new_voxel.feature_tensor.shape[0], -1)
    features[env_idx, layer, :, :] = flat_features.to(dtype=features.dtype)
    params[env_idx, layer, :3] = self.tensor_args.to_device(new_voxel.dims)
    params[env_idx, layer, 3] = new_voxel.voxel_size
    # Store the inverted (grid-from-world) pose, as the kernels expect.
    poses[env_idx, layer, :7] = (
        Pose.from_list(new_voxel.pose, self.tensor_args).inverse().get_pose_vector()
    )
    enabled[env_idx, layer] = int(True)
def update_voxel_features(
    self,
    features: torch.Tensor,
    name: Optional[str] = None,
    env_obj_idx: Optional[torch.Tensor] = None,
    env_idx: int = 0,
):
    """Overwrite the feature tensor of a voxel layer.

    Args:
        features: new per-voxel feature values; cast to the cache dtype.
        name: layer name; used when ``env_obj_idx`` is None.
        env_obj_idx: direct index into the feature tensor; bypasses lookup.
        env_idx: environment containing the layer.
    """
    if env_obj_idx is not None:
        self._voxel_tensor_list[3][env_obj_idx, :] = features.to(
            dtype=self._voxel_tensor_list[3].dtype
        )
    else:
        obs_idx = self.get_voxel_idx(name, env_idx)
        self._voxel_tensor_list[3][env_idx, obs_idx, :] = features.to(
            dtype=self._voxel_tensor_list[3].dtype
        )
def update_voxel_pose(
    self,
    w_obj_pose: Optional[Pose] = None,
    obj_w_pose: Optional[Pose] = None,
    name: Optional[str] = None,
    env_obj_idx: Optional[torch.Tensor] = None,
    env_idx: int = 0,
):
    """Update the cached pose of a voxel layer.

    Provide either ``w_obj_pose`` (layer in world frame) or ``obj_w_pose``;
    ``_get_obstacle_poses`` presumably converts to the (inverse) form stored
    in the cache — confirm against the base-class helper.

    Args:
        w_obj_pose: pose of the layer in the world frame.
        obj_w_pose: inverse pose (world in layer frame).
        name: layer name; used when ``env_obj_idx`` is None.
        env_obj_idx: direct index into the pose tensor; bypasses name lookup.
        env_idx: environment containing the layer.
    """
    obj_w_pose = self._get_obstacle_poses(w_obj_pose, obj_w_pose)
    if env_obj_idx is not None:
        self._voxel_tensor_list[1][env_obj_idx, :7] = obj_w_pose.get_pose_vector()
    else:
        obs_idx = self.get_voxel_idx(name, env_idx)
        self._voxel_tensor_list[1][env_idx, obs_idx, :7] = obj_w_pose.get_pose_vector()
def get_voxel_idx(
    self,
    name: str,
    env_idx: int = 0,
) -> int:
    """Return the layer index of voxel grid ``name`` in environment
    ``env_idx``; logs an error when the name is unknown."""
    layer_names = self._env_voxel_names[env_idx]
    if name not in layer_names:
        log_error("Obstacle with name: " + name + " not found in current world", exc_info=True)
    return layer_names.index(name)
def get_voxel_grid(
    self,
    name: str,
    env_idx: int = 0,
):
    """Reconstruct a :class:`VoxelGrid` from the cached tensors for ``name``."""
    layer = self.get_voxel_idx(name, env_idx)
    # Round to suppress float noise in dims/voxel_size read back from device.
    params = np.round(
        self._voxel_tensor_list[0][env_idx, layer, :].cpu().numpy().astype(np.float64), 6
    ).tolist()
    # NOTE(review): tensor list [1] holds the stored (inverse) pose; it is
    # returned directly as the grid pose here — confirm callers expect that.
    layer_pose = Pose(
        position=self._voxel_tensor_list[1][env_idx, layer, :3],
        quaternion=self._voxel_tensor_list[1][env_idx, layer, 3:7],
    )
    layer_features = self._voxel_tensor_list[3][env_idx, layer, :]
    return VoxelGrid(
        name=name,
        dims=params[:3],
        pose=layer_pose.to_list(),
        voxel_size=params[3],
        feature_tensor=layer_features,
    )
def get_sphere_distance(
    self,
    query_sphere,
    collision_query_buffer: CollisionQueryBuffer,
    weight: torch.Tensor,
    activation_distance: torch.Tensor,
    env_query_idx: Optional[torch.Tensor] = None,
    return_loss=False,
    sum_collisions: bool = True,
    compute_esdf: bool = False,
):
    """Signed distance of query spheres against voxel layers, combined with
    mesh/primitive obstacles handled by the base class.

    Args:
        query_sphere: (b, h, n, 4) tensor of sphere centers + radius.
        collision_query_buffer: pre-sized distance/gradient buffers.
        weight: cost weight applied inside the kernel.
        activation_distance: distance at which the cost activates.
        env_query_idx: optional per-query environment indices.
        return_loss: keep buffers for a backward pass.
        sum_collisions: sum distances across obstacle representations.
        compute_esdf: return signed distance (ESDF) instead of a cost value.

    Returns:
        Distance tensor shaped like the base-class result.
    """
    # No voxel layers loaded: defer entirely to mesh/primitive checking.
    if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
        return super().get_sphere_distance(
            query_sphere,
            collision_query_buffer,
            weight,
            activation_distance,
            env_query_idx=env_query_idx,
            return_loss=return_loss,
            sum_collisions=sum_collisions,
            compute_esdf=compute_esdf,
        )
    b, h, n, _ = query_sphere.shape  # This can be read from collision query buffer
    use_batch_env = True
    if env_query_idx is None:
        use_batch_env = False
        env_query_idx = self._env_n_voxels
    dist = SdfSphereVoxel.apply(
        query_sphere,
        collision_query_buffer.voxel_collision_buffer.distance_buffer,
        collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
        collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
        weight,
        activation_distance,
        self.max_distance,
        self._voxel_tensor_list[3],
        self._voxel_tensor_list[0],
        self._voxel_tensor_list[1],
        self._voxel_tensor_list[2],
        self._env_n_voxels,
        env_query_idx,
        self._voxel_tensor_list[0].shape[1],
        b,
        h,
        n,
        query_sphere.requires_grad,
        True,
        use_batch_env,
        return_loss,
        sum_collisions,
        compute_esdf,
    )
    # NOTE(review): voxel-only result is returned when EITHER primitive or
    # mesh checking is disabled; confirm skipping the enabled one is intended
    # (an `and` would combine with whichever representation is present).
    if (
        "primitive" not in self.collision_types
        or not self.collision_types["primitive"]
        or "mesh" not in self.collision_types
        or not self.collision_types["mesh"]
    ):
        return dist
    d_prim = super().get_sphere_distance(
        query_sphere,
        collision_query_buffer,
        weight=weight,
        activation_distance=activation_distance,
        env_query_idx=env_query_idx,
        return_loss=return_loss,
        sum_collisions=sum_collisions,
        compute_esdf=compute_esdf,
    )
    if compute_esdf:
        # ESDF semantics: take the closest obstacle across representations.
        d_val = torch.maximum(dist.view(d_prim.shape), d_prim)
    else:
        # Fix: the original read `d_val` before assignment here
        # (UnboundLocalError); the voxel result is in `dist`.
        d_val = dist.view(d_prim.shape) + d_prim
    return d_val
def get_sphere_collision(
    self,
    query_sphere,
    collision_query_buffer: CollisionQueryBuffer,
    weight: torch.Tensor,
    activation_distance: torch.Tensor,
    env_query_idx: Optional[torch.Tensor] = None,
    return_loss=False,
    **kwargs,
):
    """Binary (classification) collision check of query spheres against voxel
    layers, combined with mesh/primitive obstacles via the base class.

    Args:
        query_sphere: (b, h, n, 4) tensor of sphere centers + radius.
        collision_query_buffer: pre-sized distance/gradient buffers.
        weight: cost weight applied inside the kernel.
        activation_distance: distance at which the cost activates.
        env_query_idx: optional per-query environment indices.
        return_loss: not supported for classification; raises ValueError.

    Raises:
        ValueError: when ``return_loss`` is True.
    """
    if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
        return super().get_sphere_collision(
            query_sphere,
            collision_query_buffer,
            weight,
            activation_distance,
            env_query_idx=env_query_idx,
            return_loss=return_loss,
        )
    if return_loss:
        raise ValueError("cannot return loss for classification, use get_sphere_distance")
    b, h, n, _ = query_sphere.shape
    use_batch_env = True
    if env_query_idx is None:
        use_batch_env = False
        env_query_idx = self._env_n_voxels
    # Same kernel as get_sphere_distance; the trailing flags mirror that call
    # with distance mode off — presumably (compute_distance=False,
    # use_batch_env, return_loss=False, sum=True, esdf=False); confirm against
    # SdfSphereVoxel's signature.
    dist = SdfSphereVoxel.apply(
        query_sphere,
        collision_query_buffer.voxel_collision_buffer.distance_buffer,
        collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
        collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
        weight,
        activation_distance,
        self.max_distance,
        self._voxel_tensor_list[3],
        self._voxel_tensor_list[0],
        self._voxel_tensor_list[1],
        self._voxel_tensor_list[2],
        self._env_n_voxels,
        env_query_idx,
        self._voxel_tensor_list[0].shape[1],
        b,
        h,
        n,
        query_sphere.requires_grad,
        False,
        use_batch_env,
        False,
        True,
        False,
    )
    # NOTE(review): voxel-only result is returned when EITHER primitive or
    # mesh checking is disabled — confirm an `and` was not intended.
    if (
        "primitive" not in self.collision_types
        or not self.collision_types["primitive"]
        or "mesh" not in self.collision_types
        or not self.collision_types["mesh"]
    ):
        return dist
    d_prim = super().get_sphere_collision(
        query_sphere,
        collision_query_buffer,
        weight,
        activation_distance=activation_distance,
        env_query_idx=env_query_idx,
        return_loss=return_loss,
    )
    d_val = dist.view(d_prim.shape) + d_prim
    return d_val
def get_swept_sphere_distance(
    self,
    query_sphere,
    collision_query_buffer: CollisionQueryBuffer,
    weight: torch.Tensor,
    activation_distance: torch.Tensor,
    speed_dt: torch.Tensor,
    sweep_steps: int,
    enable_speed_metric=False,
    env_query_idx: Optional[torch.Tensor] = None,
    return_loss=False,
    sum_collisions: bool = True,
):
    """Swept signed distance of sphere trajectories against voxel layers,
    combined with mesh/primitive obstacles via the base class.

    Args:
        query_sphere: (b, h, n, 4) sphere positions over horizon ``h``.
        collision_query_buffer: pre-sized distance/gradient buffers.
        weight: cost weight applied inside the kernel.
        activation_distance: distance at which the cost activates.
        speed_dt: timestep used by the speed metric.
        sweep_steps: interpolation steps between consecutive timesteps.
        enable_speed_metric: scale the cost by sphere speed.
        env_query_idx: optional per-query environment indices.
        return_loss: keep buffers for a backward pass.
        sum_collisions: sum distances across obstacle representations.
    """
    # No voxel layers loaded: defer entirely to mesh/primitive checking.
    if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
        return super().get_swept_sphere_distance(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            env_query_idx=env_query_idx,
            sweep_steps=sweep_steps,
            activation_distance=activation_distance,
            speed_dt=speed_dt,
            enable_speed_metric=enable_speed_metric,
            return_loss=return_loss,
            sum_collisions=sum_collisions,
        )
    b, h, n, _ = query_sphere.shape
    use_batch_env = True
    if env_query_idx is None:
        use_batch_env = False
        env_query_idx = self._env_n_voxels
    dist = SdfSweptSphereVoxel.apply(
        query_sphere,
        collision_query_buffer.voxel_collision_buffer.distance_buffer,
        collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
        collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
        weight,
        activation_distance,
        self.max_distance,
        speed_dt,
        self._voxel_tensor_list[3],
        self._voxel_tensor_list[0],
        self._voxel_tensor_list[1],
        self._voxel_tensor_list[2],
        self._env_n_voxels,
        env_query_idx,
        self._voxel_tensor_list[0].shape[1],
        b,
        h,
        n,
        sweep_steps,
        enable_speed_metric,
        query_sphere.requires_grad,
        True,
        use_batch_env,
        return_loss,
        sum_collisions,
    )
    # NOTE(review): voxel-only result is returned when EITHER primitive or
    # mesh checking is disabled — confirm an `and` was not intended.
    if (
        "primitive" not in self.collision_types
        or not self.collision_types["primitive"]
        or "mesh" not in self.collision_types
        or not self.collision_types["mesh"]
    ):
        return dist
    d_prim = super().get_swept_sphere_distance(
        query_sphere,
        collision_query_buffer,
        weight=weight,
        env_query_idx=env_query_idx,
        sweep_steps=sweep_steps,
        activation_distance=activation_distance,
        speed_dt=speed_dt,
        enable_speed_metric=enable_speed_metric,
        return_loss=return_loss,
        sum_collisions=sum_collisions,
    )
    d_val = dist.view(d_prim.shape) + d_prim
    return d_val
def get_swept_sphere_collision(
    self,
    query_sphere,
    collision_query_buffer: CollisionQueryBuffer,
    weight: torch.Tensor,
    activation_distance: torch.Tensor,
    speed_dt: torch.Tensor,
    sweep_steps: int,
    enable_speed_metric=False,
    env_query_idx: Optional[torch.Tensor] = None,
    return_loss=False,
):
    """Classify swept robot spheres for collision against the voxel-grid layer.

    Mirrors ``get_swept_sphere_distance`` but invokes the CUDA kernel in
    classification mode, so the result is a binary-style collision measure
    rather than a differentiable distance; ``return_loss`` is therefore
    rejected.  When primitive and mesh layers are also active, their swept
    collision result is added to the voxel result.

    Args:
        query_sphere: Sphere tensor of shape (b, h, n, 4) — position + radius.
        collision_query_buffer: Pre-allocated buffers for kernel outputs.
        weight: Cost weight passed through to the kernel.
        activation_distance: Distance at which the cost activates.
        speed_dt: Time step used by the speed metric.
        sweep_steps: Number of interpolation steps along the sweep.
        enable_speed_metric: Enable velocity-scaled collision metric.
        env_query_idx: Optional per-batch environment indices; when None a
            single (non-batched) environment lookup is used.
        return_loss: Unsupported here; raises ValueError if True.
    """
    # Voxel layer disabled: defer entirely to the primitive/mesh checkers.
    if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
        return super().get_swept_sphere_collision(
            query_sphere,
            collision_query_buffer,
            weight=weight,
            env_query_idx=env_query_idx,
            sweep_steps=sweep_steps,
            activation_distance=activation_distance,
            speed_dt=speed_dt,
            enable_speed_metric=enable_speed_metric,
            return_loss=return_loss,
        )
    if return_loss:
        raise ValueError("cannot return loss for classify, use get_swept_sphere_distance")
    b, h, n, _ = query_sphere.shape
    use_batch_env = True
    if env_query_idx is None:
        # Fall back to the per-env voxel-count tensor as a placeholder index.
        use_batch_env = False
        env_query_idx = self._env_n_voxels
    # Swept-sphere query against cached voxel features; the `False` flag below
    # presumably selects classification rather than gradient/distance mode —
    # the distance variant passes True in the same position (confirm in kernel).
    dist = SdfSweptSphereVoxel.apply(
        query_sphere,
        collision_query_buffer.voxel_collision_buffer.distance_buffer,
        collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
        collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
        weight,
        activation_distance,
        self.max_distance,
        speed_dt,
        self._voxel_tensor_list[3],
        self._voxel_tensor_list[0],
        self._voxel_tensor_list[1],
        self._voxel_tensor_list[2],
        self._env_n_voxels,
        env_query_idx,
        self._voxel_tensor_list[0].shape[1],
        b,
        h,
        n,
        sweep_steps,
        enable_speed_metric,
        query_sphere.requires_grad,
        False,
        use_batch_env,
        return_loss,
        True,
    )
    # NOTE(review): results are merged with the parent checker only when BOTH
    # primitive and mesh layers are enabled; with exactly one enabled, the
    # voxel-only result is returned — confirm this is intended.
    if (
        "primitive" not in self.collision_types
        or not self.collision_types["primitive"]
        or "mesh" not in self.collision_types
        or not self.collision_types["mesh"]
    ):
        return dist
    d_prim = super().get_swept_sphere_collision(
        query_sphere,
        collision_query_buffer,
        weight=weight,
        env_query_idx=env_query_idx,
        sweep_steps=sweep_steps,
        activation_distance=activation_distance,
        speed_dt=speed_dt,
        enable_speed_metric=enable_speed_metric,
        return_loss=return_loss,
    )
    # Combine voxel-layer and primitive/mesh-layer collision measures.
    d_val = dist.view(d_prim.shape) + d_prim
    return d_val
def clear_cache(self):
    """Reset the cached voxel-grid collision data to an empty world.

    Zeros the sparsity/feature buffers, resets stored features to
    -max_distance (i.e. maximally far from any surface), and clears the
    per-environment voxel counts.
    """
    if self._voxel_tensor_list is not None:
        # [2] holds per-voxel occupancy/sparsity flags; [-1] holds features.
        self._voxel_tensor_list[2][:] = 0
        self._voxel_tensor_list[-1][:] = -1.0 * self.max_distance
        # No voxels registered in any environment after a cache clear.
        self._env_n_voxels[:] = 0

View File

@@ -18,6 +18,7 @@ import warp as wp
# CuRobo
from curobo.curobolib.kinematics import rotation_matrix_to_quaternion
from curobo.util.logger import log_error
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.warp import init_warp
@@ -27,11 +28,11 @@ def transform_points(
if out_points is None:
out_points = torch.zeros((points.shape[0], 3), device=points.device, dtype=points.dtype)
if out_gp is None:
out_gp = torch.zeros((position.shape[0], 3), device=position.device)
out_gp = torch.zeros((position.shape[0], 3), device=position.device, dtype=points.dtype)
if out_gq is None:
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device)
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device, dtype=points.dtype)
if out_gpt is None:
out_gpt = torch.zeros((points.shape[0], 3), device=position.device)
out_gpt = torch.zeros((points.shape[0], 3), device=position.device, dtype=points.dtype)
out_points = TransformPoint.apply(
position, quaternion, points, out_points, out_gp, out_gq, out_gpt
)
@@ -46,18 +47,20 @@ def batch_transform_points(
(points.shape[0], points.shape[1], 3), device=points.device, dtype=points.dtype
)
if out_gp is None:
out_gp = torch.zeros((position.shape[0], 3), device=position.device)
out_gp = torch.zeros((position.shape[0], 3), device=position.device, dtype=points.dtype)
if out_gq is None:
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device)
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device, dtype=points.dtype)
if out_gpt is None:
out_gpt = torch.zeros((points.shape[0], points.shape[1], 3), device=position.device)
out_gpt = torch.zeros(
(points.shape[0], points.shape[1], 3), device=position.device, dtype=points.dtype
)
out_points = BatchTransformPoint.apply(
position, quaternion, points, out_points, out_gp, out_gq, out_gpt
)
return out_points
@torch.jit.script
@get_torch_jit_decorator()
def get_inv_transform(w_rot_c, w_trans_c):
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
c_rot_w = w_rot_c.transpose(-1, -2)
@@ -65,7 +68,7 @@ def get_inv_transform(w_rot_c, w_trans_c):
return c_rot_w, c_trans_w
@torch.jit.script
@get_torch_jit_decorator()
def transform_point_inverse(point, rot, trans):
# type: (Tensor, Tensor, Tensor) -> Tensor

View File

@@ -11,6 +11,7 @@
from __future__ import annotations
# Standard Library
import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Union
@@ -564,6 +565,68 @@ class PointCloud(Obstacle):
return new_spheres
@dataclass
class VoxelGrid(Obstacle):
    """Dense, axis-aligned voxel-grid obstacle with a per-voxel feature value.

    The grid is centered at the origin of its own frame, spans ``dims`` and is
    discretized into cubic cells of edge length ``voxel_size``.  The feature
    is presumably an SDF-like signed value (see ``get_occupied_voxels``) —
    confirm against the collision checker that consumes it.
    """

    # Grid extents [x, y, z]; same units as voxel_size (meters).
    dims: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0])
    voxel_size: float = 0.02  # meters
    # Flattened per-voxel feature values; dtype mirrored into feature_dtype.
    feature_tensor: Optional[torch.Tensor] = None
    # Per-voxel [x, y, z, r] rows; r is half the voxel size (see create_xyzr_tensor).
    xyzr_tensor: Optional[torch.Tensor] = None
    feature_dtype: torch.dtype = torch.float32

    def __post_init__(self):
        # Keep the advertised dtype in sync with an explicitly provided tensor.
        if self.feature_tensor is not None:
            self.feature_dtype = self.feature_tensor.dtype

    def create_xyzr_tensor(
        self, transform_to_origin: bool = False, tensor_args: TensorDeviceType = TensorDeviceType()
    ):
        """Build the (n_voxels, 4) tensor of voxel centers and radii.

        Args:
            transform_to_origin: If True, transform voxel centers by ``self.pose``.
            tensor_args: Device container for tensor creation.
                NOTE(review): a shared instance is the default argument — safe
                only if TensorDeviceType is effectively immutable; confirm.

        Returns:
            Tensor of shape (n_voxels, 4): [x, y, z, voxel_size / 2] per voxel.
        """
        bounds = self.dims
        # Grid is centered at the origin of the voxel-grid frame.
        low = [-bounds[0] / 2, -bounds[1] / 2, -bounds[2] / 2]
        high = [bounds[0] / 2, bounds[1] / 2, bounds[2] / 2]
        trange = [h - l for l, h in zip(low, high)]
        x = torch.linspace(
            low[0], high[0], int(math.floor(trange[0] / self.voxel_size)), device=tensor_args.device
        )
        y = torch.linspace(
            low[1], high[1], int(math.floor(trange[1] / self.voxel_size)), device=tensor_args.device
        )
        z = torch.linspace(
            low[2], high[2], int(math.floor(trange[2] / self.voxel_size)), device=tensor_args.device
        )
        w, l, h = x.shape[0], y.shape[0], z.shape[0]  # NOTE(review): unused; `l` shadows a builtin
        # Cartesian product of axis samples -> flat (n, 3) list of centers.
        xyz = (
            torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
        )
        if transform_to_origin:
            pose = Pose.from_list(self.pose, tensor_args=tensor_args)
            xyz = pose.transform_points(xyz.contiguous())
        # Every voxel shares the same bounding radius: half the edge length.
        r = torch.zeros_like(xyz[:, 0:1]) + (self.voxel_size * 0.5)
        xyzr = torch.cat([xyz, r], dim=1)
        return xyzr

    def get_occupied_voxels(self, feature_threshold: Optional[float] = None):
        """Return [x, y, z, feature] rows for voxels whose feature exceeds the threshold.

        The default threshold is -voxel_size, i.e. voxels whose feature is
        more than one voxel "inside" are dropped (assuming an SDF-style
        feature — TODO confirm).
        """
        if feature_threshold is None:
            feature_threshold = -1.0 * self.voxel_size
        if self.xyzr_tensor is None or self.feature_tensor is None:
            log_error("Feature tensor or xyzr tensor is empty")
        xyzr = self.xyzr_tensor.clone()
        # Replace the radius column with the feature value in the returned rows.
        xyzr[:, 3] = self.feature_tensor
        occupied = xyzr[self.feature_tensor > feature_threshold]
        return occupied

    def clone(self):
        """Return a deep copy (tensors cloned, lists copied)."""
        return VoxelGrid(
            name=self.name,
            pose=self.pose.copy(),
            dims=self.dims.copy(),
            feature_tensor=self.feature_tensor.clone() if self.feature_tensor is not None else None,
            xyzr_tensor=self.xyzr_tensor.clone() if self.xyzr_tensor is not None else None,
            feature_dtype=self.feature_dtype,
            voxel_size=self.voxel_size,
        )
@dataclass
class WorldConfig(Sequence):
"""Representation of World for use in CuRobo."""
@@ -586,25 +649,13 @@ class WorldConfig(Sequence):
#: BloxMap obstacle.
blox: Optional[List[BloxMap]] = None
voxel: Optional[List[VoxelGrid]] = None
#: List of all obstacles in world.
objects: Optional[List[Obstacle]] = None
def __post_init__(self):
# create objects list:
if self.objects is None:
self.objects = []
if self.sphere is not None:
self.objects += self.sphere
if self.cuboid is not None:
self.objects += self.cuboid
if self.capsule is not None:
self.objects += self.capsule
if self.mesh is not None:
self.objects += self.mesh
if self.blox is not None:
self.objects += self.blox
if self.cylinder is not None:
self.objects += self.cylinder
if self.sphere is None:
self.sphere = []
if self.cuboid is None:
@@ -617,6 +668,18 @@ class WorldConfig(Sequence):
self.cylinder = []
if self.blox is None:
self.blox = []
if self.voxel is None:
self.voxel = []
if self.objects is None:
self.objects = (
self.sphere
+ self.cuboid
+ self.capsule
+ self.mesh
+ self.cylinder
+ self.blox
+ self.voxel
)
def __len__(self):
return len(self.objects)
@@ -632,6 +695,7 @@ class WorldConfig(Sequence):
capsule=self.capsule.copy() if self.capsule is not None else None,
cylinder=self.cylinder.copy() if self.cylinder is not None else None,
blox=self.blox.copy() if self.blox is not None else None,
voxel=self.voxel.copy() if self.voxel is not None else None,
)
@staticmethod
@@ -642,6 +706,7 @@ class WorldConfig(Sequence):
mesh = None
blox = None
cylinder = None
voxel = None
# load yaml:
if "cuboid" in data_dict.keys():
cuboid = [Cuboid(name=x, **data_dict["cuboid"][x]) for x in data_dict["cuboid"]]
@@ -655,6 +720,8 @@ class WorldConfig(Sequence):
cylinder = [Cylinder(name=x, **data_dict["cylinder"][x]) for x in data_dict["cylinder"]]
if "blox" in data_dict.keys():
blox = [BloxMap(name=x, **data_dict["blox"][x]) for x in data_dict["blox"]]
if "voxel" in data_dict.keys():
voxel = [VoxelGrid(name=x, **data_dict["voxel"][x]) for x in data_dict["voxel"]]
return WorldConfig(
cuboid=cuboid,
@@ -663,6 +730,7 @@ class WorldConfig(Sequence):
cylinder=cylinder,
mesh=mesh,
blox=blox,
voxel=voxel,
)
# load world config as obbs: convert all types to obbs
@@ -688,6 +756,10 @@ class WorldConfig(Sequence):
if current_world.mesh is not None and len(current_world.mesh) > 0:
mesh_obb = [x.get_cuboid() for x in current_world.mesh]
if current_world.voxel is not None and len(current_world.voxel) > 0:
log_error("VoxelGrid cannot be converted to obb world")
return WorldConfig(
cuboid=cuboid_obb + sphere_obb + capsule_obb + cylinder_obb + mesh_obb + blox_obb
)
@@ -714,6 +786,8 @@ class WorldConfig(Sequence):
for i in range(len(current_world.blox)):
if current_world.blox[i].mesh is not None:
blox_obb.append(current_world.blox[i].get_mesh(process=process))
if current_world.voxel is not None and len(current_world.voxel) > 0:
log_error("VoxelGrid cannot be converted to mesh world")
return WorldConfig(
mesh=current_world.mesh
@@ -750,6 +824,7 @@ class WorldConfig(Sequence):
return WorldConfig(
mesh=current_world.mesh + sphere_obb + capsule_obb + cylinder_obb + blox_obb,
cuboid=cuboid_obb,
voxel=current_world.voxel,
)
@staticmethod
@@ -822,6 +897,8 @@ class WorldConfig(Sequence):
self.cylinder.append(obstacle)
elif isinstance(obstacle, Capsule):
self.capsule.append(obstacle)
elif isinstance(obstacle, VoxelGrid):
self.voxel.append(obstacle)
else:
ValueError("Obstacle type not supported")
self.objects.append(obstacle)

View File

@@ -33,6 +33,7 @@ from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState, RobotConfig, State
from curobo.util.logger import log_info, log_warn
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.trajectory import InterpolateType, get_interpolated_trajectory
from curobo.util_file import (
get_robot_configs_path,
@@ -1029,7 +1030,7 @@ class GraphPlanBase(GraphConfig):
pass
@torch.jit.script
@get_torch_jit_decorator(dynamic=True)
def get_unique_nodes(dist_node: torch.Tensor, nodes: torch.Tensor, node_distance: float):
node_flag = dist_node <= node_distance
dist_node[node_flag] = 0.0
@@ -1042,7 +1043,7 @@ def get_unique_nodes(dist_node: torch.Tensor, nodes: torch.Tensor, node_distance
return unique_nodes, n_inv
@torch.jit.script
@get_torch_jit_decorator(force_jit=True, dynamic=True)
def add_new_nodes_jit(
nodes, new_nodes, flag, cat_buffer, path, idx, i: int, dof: int
) -> Tuple[torch.Tensor, torch.Tensor, int]:
@@ -1066,7 +1067,7 @@ def add_new_nodes_jit(
return path, node_set, new_nodes.shape[0]
@torch.jit.script
@get_torch_jit_decorator(force_jit=True, dynamic=True)
def add_all_nodes_jit(
nodes, cat_buffer, path, i: int, dof: int
) -> Tuple[torch.Tensor, torch.Tensor, int]:
@@ -1084,20 +1085,20 @@ def add_all_nodes_jit(
return path, node_set, nodes.shape[0]
@torch.jit.script
@get_torch_jit_decorator(force_jit=True, dynamic=True)
def compute_distance_norm_jit(pt, batch_pts, distance_weight):
vec = (batch_pts - pt) * distance_weight
dist = torch.norm(vec, dim=-1)
return dist
@torch.jit.script
@get_torch_jit_decorator(dynamic=True)
def compute_distance_jit(pt, batch_pts, distance_weight):
    """Weighted difference vectors from ``pt`` to every point in ``batch_pts``."""
    return (batch_pts - pt) * distance_weight
@torch.jit.script
@get_torch_jit_decorator(dynamic=True)
def compute_rotation_frame_jit(
x_start: torch.Tensor, x_goal: torch.Tensor, rot_frame_col: torch.Tensor
) -> torch.Tensor:
@@ -1114,7 +1115,7 @@ def compute_rotation_frame_jit(
return C
@torch.jit.script
@get_torch_jit_decorator(force_jit=True, dynamic=True)
def biased_vertex_projection_jit(
x_start,
x_goal,
@@ -1144,7 +1145,7 @@ def biased_vertex_projection_jit(
return x_samples
@torch.jit.script
@get_torch_jit_decorator(force_jit=True, dynamic=True)
def cat_xc_jit(x, n: int):
c = x[:, 0:1] * 0.0
xc_search = torch.cat((x, c), dim=1)[:n, :]

View File

@@ -20,12 +20,13 @@ import torch.autograd.profiler as profiler
# CuRobo
from curobo.curobolib.opt import LBFGScu
from curobo.opt.newton.newton_base import NewtonOptBase, NewtonOptConfig
from curobo.util.logger import log_warn
from curobo.util.logger import log_info
from curobo.util.torch_utils import get_torch_jit_decorator
# kernel for l-bfgs:
# @torch.jit.script
def compute_step_direction(
@get_torch_jit_decorator()
def jit_lbfgs_compute_step_direction(
alpha_buffer,
rho_buffer,
y_buffer,
@@ -35,6 +36,8 @@ def compute_step_direction(
epsilon: float,
stable_mode: bool = True,
):
grad_q = grad_q.transpose(-1, -2)
# m = 15 (int)
# y_buffer, s_buffer: m x b x 175
# q, grad_q: b x 175
@@ -60,12 +63,39 @@ def compute_step_direction(
return -1.0 * r
@get_torch_jit_decorator()
def jit_lbfgs_update_buffers(
    q, grad_q, s_buffer, y_buffer, rho_buffer, x_0, grad_0, stable_mode: bool
):
    """Push the newest L-BFGS curvature pair into the rolling history buffers.

    Computes y = g_k - g_{k-1}, s = x_k - x_{k-1} and rho = 1 / (y^T s),
    stores them so that the most recent entry ends up at index -1 of each
    m-deep buffer, and records the current iterate/gradient in x_0/grad_0
    (all updates are in place).

    Returns:
        (s_buffer, y_buffer, rho_buffer, x_0, grad_0) after the in-place update.
    """
    # Reshape to column vectors: gradients arrive as (..., 1, d) -> (..., d, 1).
    grad_col = grad_q.transpose(-1, -2)
    q_col = q.unsqueeze(-1)
    grad_diff = grad_col - grad_0
    step_diff = q_col - x_0
    curvature = grad_diff.transpose(-1, -2) @ step_diff
    rho_new = 1 / curvature
    if stable_mode:
        # Zero out inf/nan from a degenerate (zero-curvature) step.
        rho_new = torch.nan_to_num(rho_new, 0.0, 0.0, 0.0)
    # Overwrite the oldest slot, then roll so the newest entry sits at index -1.
    s_buffer[0] = step_diff
    s_buffer[:] = torch.roll(s_buffer, -1, dims=0)
    y_buffer[0] = grad_diff
    y_buffer[:] = torch.roll(y_buffer, -1, dims=0)
    rho_buffer[0] = rho_new
    rho_buffer[:] = torch.roll(rho_buffer, -1, dims=0)
    x_0.copy_(q_col)
    grad_0.copy_(grad_col)
    return s_buffer, y_buffer, rho_buffer, x_0, grad_0
@dataclass
class LBFGSOptConfig(NewtonOptConfig):
history: int = 10
epsilon: float = 0.01
use_cuda_kernel: bool = True
stable_mode: bool = True
use_shared_buffers_kernel: bool = True
def __post_init__(self):
return super().__post_init__()
@@ -77,11 +107,15 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
if config is not None:
LBFGSOptConfig.__init__(self, **vars(config))
NewtonOptBase.__init__(self)
if self.d_opt >= 1024 or self.history > 15:
log_warn("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>15")
if (
self.d_opt >= 1024
or self.history > 31
or ((self.d_opt * self.history + 33) * 4 >= 48000)
):
log_info("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>15")
self.use_cuda_kernel = False
if self.history >= self.d_opt:
log_warn("LBFGS: history >= d_opt, reducing history to d_opt-1")
if self.history > self.d_opt:
log_info("LBFGS: history >= d_opt, reducing history to d_opt-1")
self.history = self.d_opt - 1
@profiler.record_function("lbfgs/reset")
@@ -130,7 +164,7 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
def _get_step_direction(self, cost, q, grad_q):
if self.use_cuda_kernel:
with profiler.record_function("lbfgs/fused"):
dq = LBFGScu._call_cuda(
dq = LBFGScu.apply(
self.step_q_buffer,
self.rho_buffer,
self.y_buffer,
@@ -141,13 +175,14 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
self.grad_0,
self.epsilon,
self.stable_mode,
self.use_shared_buffers_kernel,
)
else:
grad_q = grad_q.transpose(-1, -2)
q = q.unsqueeze(-1)
self._update_buffers(q, grad_q)
dq = compute_step_direction(
dq = jit_lbfgs_compute_step_direction(
self.alpha_buffer,
self.rho_buffer,
self.y_buffer,
@@ -177,6 +212,23 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
return -1.0 * r
def _update_buffers(self, q, grad_q):
if True:
self.s_buffer, self.y_buffer, self.rho_buffer, self.x_0, self.grad_0 = (
jit_lbfgs_update_buffers(
q,
grad_q,
self.s_buffer,
self.y_buffer,
self.rho_buffer,
self.x_0,
self.grad_0,
self.stable_mode,
)
)
return
grad_q = grad_q.transpose(-1, -2)
q = q.unsqueeze(-1)
y = grad_q - self.grad_0
s = q - self.x_0
rho = 1 / (y.transpose(-1, -2) @ s)

View File

@@ -26,6 +26,7 @@ from curobo.opt.opt_base import Optimizer, OptimizerConfig
from curobo.rollout.dynamics_model.integration_utils import build_fd_matrix
from curobo.types.base import TensorDeviceType
from curobo.types.tensor import T_BDOF, T_BHDOF_float, T_BHValue_float, T_BValue_float, T_HDOF_float
from curobo.util.torch_utils import get_torch_jit_decorator
class LineSearchType(Enum):
@@ -108,6 +109,7 @@ class NewtonOptBase(Optimizer, NewtonOptConfig):
self.action_horizon, device=self.tensor_args.device, dtype=self.tensor_args.dtype
).unsqueeze(0)
self._temporal_mat += eye_mat
self.rollout_fn.sum_horizon = True
def reset_cuda_graph(self):
if self.cu_opt_graph is not None:
@@ -222,11 +224,14 @@ class NewtonOptBase(Optimizer, NewtonOptConfig):
self.n_problems * self.num_particles, self.action_horizon, self.rollout_fn.d_action
)
trajectories = self.rollout_fn(x_in) # x_n = (batch*line_search_scale) x horizon x d_action
cost = torch.sum(
trajectories.costs.view(self.n_problems, self.num_particles, self.horizon),
dim=-1,
keepdim=True,
)
if len(trajectories.costs.shape) == 2:
cost = torch.sum(
trajectories.costs.view(self.n_problems, self.num_particles, self.horizon),
dim=-1,
keepdim=True,
)
else:
cost = trajectories.costs.view(self.n_problems, self.num_particles, 1)
g_x = cost.backward(gradient=self.l_vec, retain_graph=False)
g_x = x_n.grad.detach()
return (
@@ -542,7 +547,7 @@ class NewtonOptBase(Optimizer, NewtonOptConfig):
)
@torch.jit.script
@get_torch_jit_decorator()
def get_x_set_jit(step_vec, x, alpha_list, action_lows, action_highs):
# step_direction = step_direction.detach()
x_set = torch.clamp(x.unsqueeze(-2) + alpha_list * step_vec, action_lows, action_highs)
@@ -550,7 +555,7 @@ def get_x_set_jit(step_vec, x, alpha_list, action_lows, action_highs):
return x_set
@torch.jit.script
@get_torch_jit_decorator()
def _armijo_line_search_tail_jit(c, g_x, step_direction, c_1, alpha_list, c_idx, x_set, d_opt):
c_0 = c[:, 0:1]
g_0 = g_x[:, 0:1]
@@ -581,7 +586,7 @@ def _armijo_line_search_tail_jit(c, g_x, step_direction, c_1, alpha_list, c_idx,
return (best_x, best_c, best_grad)
@torch.jit.script
@get_torch_jit_decorator()
def _wolfe_search_tail_jit(c, g_x, x_set, m, d_opt: int):
b, h, _ = x_set.shape
g_x = g_x.view(b * h, -1)
@@ -593,7 +598,7 @@ def _wolfe_search_tail_jit(c, g_x, x_set, m, d_opt: int):
return (best_x, best_c, best_grad)
@torch.jit.script
@get_torch_jit_decorator()
def scale_action(dx, action_step_max):
scale_value = torch.max(torch.abs(dx) / action_step_max, dim=-1, keepdim=True)[0]
scale_value = torch.clamp(scale_value, 1.0)
@@ -601,7 +606,7 @@ def scale_action(dx, action_step_max):
return dx_scaled
@torch.jit.script
@get_torch_jit_decorator()
def check_convergence(
best_iteration: torch.Tensor, current_iteration: torch.Tensor, last_best: int
) -> bool:

View File

@@ -16,7 +16,15 @@ from typing import Optional
import torch
# CuRobo
from curobo.opt.particle.parallel_mppi import CovType, ParallelMPPI, ParallelMPPIConfig
from curobo.opt.particle.parallel_mppi import (
CovType,
ParallelMPPI,
ParallelMPPIConfig,
Trajectory,
jit_blend_mean,
)
from curobo.opt.particle.particle_opt_base import SampleMode
from curobo.util.torch_utils import get_torch_jit_decorator
@dataclass
@@ -38,14 +46,72 @@ class ParallelES(ParallelMPPI, ParallelESConfig):
)
# get the new means from here
# use Evolutionary Strategy Mean here:
new_mean = jit_blend_mean(self.mean_action, new_mean, self.step_size_mean)
return new_mean
def _exp_util_from_costs(self, costs):
total_costs = self._compute_total_cost(costs)
w = self._exp_util(total_costs)
return w
def _exp_util(self, total_costs):
w = calc_exp(total_costs)
return w
def _compute_mean_covariance(self, costs, actions):
w = self._exp_util_from_costs(costs)
w = w.unsqueeze(-1).unsqueeze(-1)
new_mean = self._compute_mean(w, actions)
new_cov = self._compute_covariance(w, actions)
self._update_cov_scale(new_cov)
@torch.jit.script
return new_mean, new_cov
@torch.no_grad()
def _update_distribution(self, trajectories: Trajectory):
costs = trajectories.costs
actions = trajectories.actions
# Let's reshape to n_problems now:
# first find the means before doing exponential utility:
# Update best action
if self.sample_mode == SampleMode.BEST:
w = self._exp_util_from_costs(costs)
best_idx = torch.argmax(w, dim=-1)
self.best_traj.copy_(actions[self.problem_col, best_idx])
if self.store_rollouts and self.visual_traj is not None:
total_costs = self._compute_total_cost(costs)
vis_seq = getattr(trajectories.state, self.visual_traj)
top_values, top_idx = torch.topk(total_costs, 20, dim=1)
self.top_values = top_values
self.top_idx = top_idx
top_trajs = torch.index_select(vis_seq, 0, top_idx[0])
for i in range(1, top_idx.shape[0]):
trajs = torch.index_select(
vis_seq, 0, top_idx[i] + (self.particles_per_problem * i)
)
top_trajs = torch.cat((top_trajs, trajs), dim=0)
if self.top_trajs is None or top_trajs.shape != self.top_trajs:
self.top_trajs = top_trajs
else:
self.top_trajs.copy_(top_trajs)
if not self.update_cov:
w = w.unsqueeze(-1).unsqueeze(-1)
new_mean = self._compute_mean(w, actions)
else:
new_mean, new_cov = self._compute_mean_covariance(costs, actions)
self.cov_action.copy_(new_cov)
self.mean_action.copy_(new_mean)
@get_torch_jit_decorator()
def calc_exp(
total_costs,
):
@@ -58,7 +124,7 @@ def calc_exp(
return w
@torch.jit.script
@get_torch_jit_decorator()
def compute_es_mean(
w, actions, mean_action, full_inv_cov, num_particles: int, learning_rate: float
):

View File

@@ -33,6 +33,7 @@ from curobo.types.robot import State
from curobo.util.logger import log_info
from curobo.util.sample_lib import HaltonSampleLib, SampleConfig, SampleLib
from curobo.util.tensor_util import copy_tensor
from curobo.util.torch_utils import get_torch_jit_decorator
class BaseActionType(Enum):
@@ -187,11 +188,43 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
# w = torch.softmax((-1.0 / self.beta) * total_costs, dim=-1)
return w
def _exp_util_from_costs(self, costs):
w = jit_calculate_exp_util_from_costs(costs, self.gamma_seq, self.beta)
return w
def _compute_mean(self, w, actions):
# get the new means from here
new_mean = torch.sum(w * actions, dim=-3)
new_mean = jit_blend_mean(self.mean_action, new_mean, self.step_size_mean)
return new_mean
def _compute_mean_covariance(self, costs, actions):
if self.cov_type == CovType.FULL_A:
log_error("Not implemented")
if self.cov_type == CovType.DIAG_A:
new_mean, new_cov, new_scale_tril = jit_mean_cov_diag_a(
costs,
actions,
self.gamma_seq,
self.mean_action,
self.cov_action,
self.step_size_mean,
self.step_size_cov,
self.kappa,
self.beta,
)
self.scale_tril.copy_(new_scale_tril)
# self._update_cov_scale(new_cov)
else:
w = self._exp_util_from_costs(costs)
w = w.unsqueeze(-1).unsqueeze(-1)
new_mean = self._compute_mean(w, actions)
new_cov = self._compute_covariance(w, actions)
self._update_cov_scale(new_cov)
return new_mean, new_cov
def _compute_covariance(self, w, actions):
if not self.update_cov:
return
@@ -200,27 +233,13 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
if self.cov_type == CovType.SIGMA_I:
delta_actions = actions - self.mean_action.unsqueeze(-3)
# weighted_delta = w * (delta ** 2).T
# cov_update = torch.ean(torch.sum(weighted_delta.T, dim=0))
# print(cov_update.shape, self.cov_action)
weighted_delta = w * (delta_actions**2)
cov_update = torch.mean(
torch.sum(torch.sum(weighted_delta, dim=-2), dim=-1), dim=-1, keepdim=True
)
# raise NotImplementedError("Need to implement covariance update of form sigma*I")
elif self.cov_type == CovType.DIAG_A:
# Diagonal covariance of size AxA
# n, b, h, d = delta_actions.shape
# delta_actions = delta_actions.view(n,b,h*d)
# weighted_delta = w * (delta_actions**2)
# weighted_delta =
# sum across horizon and mean across particles:
# cov_update = torch.diag(torch.mean(torch.sum(weighted_delta.T , dim=0), dim=0))
# cov_update = torch.mean(torch.sum(weighted_delta, dim=-2), dim=-2).unsqueeze(
# -2
# ) # .expand(-1,-1,-1)
cov_update = jit_diag_a_cov_update(w, actions, self.mean_action)
elif self.cov_type == CovType.FULL_A:
@@ -241,19 +260,22 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
else:
raise ValueError("Unidentified covariance type in update_distribution")
cov_update = jit_blend_cov(self.cov_action, cov_update, self.step_size_cov, self.kappa)
return cov_update
def _update_cov_scale(self):
def _update_cov_scale(self, new_cov=None):
if new_cov is None:
new_cov = self.cov_action
if not self.update_cov:
return
if self.cov_type == CovType.SIGMA_I:
self.scale_tril = torch.sqrt(self.cov_action)
self.scale_tril = torch.sqrt(new_cov)
elif self.cov_type == CovType.DIAG_A:
self.scale_tril.copy_(torch.sqrt(self.cov_action))
self.scale_tril.copy_(torch.sqrt(new_cov))
elif self.cov_type == CovType.FULL_A:
self.scale_tril = matrix_cholesky(self.cov_action)
self.scale_tril = matrix_cholesky(new_cov)
elif self.cov_type == CovType.FULL_HA:
raise NotImplementedError
@@ -263,44 +285,44 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
costs = trajectories.costs
actions = trajectories.actions
total_costs = self._compute_total_cost(costs)
# Let's reshape to n_problems now:
# first find the means before doing exponential utility:
w = self._exp_util(total_costs)
# Update best action
if self.sample_mode == SampleMode.BEST:
best_idx = torch.argmax(w, dim=-1)
self.best_traj.copy_(actions[self.problem_col, best_idx])
with profiler.record_function("mppi/get_best"):
if self.store_rollouts and self.visual_traj is not None:
vis_seq = getattr(trajectories.state, self.visual_traj)
top_values, top_idx = torch.topk(total_costs, 20, dim=1)
self.top_values = top_values
self.top_idx = top_idx
top_trajs = torch.index_select(vis_seq, 0, top_idx[0])
for i in range(1, top_idx.shape[0]):
trajs = torch.index_select(
vis_seq, 0, top_idx[i] + (self.particles_per_problem * i)
)
top_trajs = torch.cat((top_trajs, trajs), dim=0)
if self.top_trajs is None or top_trajs.shape != self.top_trajs:
self.top_trajs = top_trajs
else:
self.top_trajs.copy_(top_trajs)
# Update best action
if self.sample_mode == SampleMode.BEST:
w = self._exp_util_from_costs(costs)
best_idx = torch.argmax(w, dim=-1)
self.best_traj.copy_(actions[self.problem_col, best_idx])
with profiler.record_function("mppi/store_rollouts"):
w = w.unsqueeze(-1).unsqueeze(-1)
if self.store_rollouts and self.visual_traj is not None:
total_costs = self._compute_total_cost(costs)
vis_seq = getattr(trajectories.state, self.visual_traj)
top_values, top_idx = torch.topk(total_costs, 20, dim=1)
self.top_values = top_values
self.top_idx = top_idx
top_trajs = torch.index_select(vis_seq, 0, top_idx[0])
for i in range(1, top_idx.shape[0]):
trajs = torch.index_select(
vis_seq, 0, top_idx[i] + (self.particles_per_problem * i)
)
top_trajs = torch.cat((top_trajs, trajs), dim=0)
if self.top_trajs is None or top_trajs.shape != self.top_trajs:
self.top_trajs = top_trajs
else:
self.top_trajs.copy_(top_trajs)
new_mean = self._compute_mean(w, actions)
# print(new_mean)
if self.update_cov:
cov_update = self._compute_covariance(w, actions)
new_cov = jit_blend_cov(self.cov_action, cov_update, self.step_size_cov, self.kappa)
if not self.update_cov:
w = self._exp_util_from_costs(costs)
w = w.unsqueeze(-1).unsqueeze(-1)
new_mean = self._compute_mean(w, actions)
else:
new_mean, new_cov = self._compute_mean_covariance(costs, actions)
self.cov_action.copy_(new_cov)
self._update_cov_scale()
new_mean = jit_blend_mean(self.mean_action, new_mean, self.step_size_mean)
self.mean_action.copy_(new_mean)
@torch.no_grad()
@@ -591,20 +613,28 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
return super().generate_rollouts(init_act)
@torch.jit.script
@get_torch_jit_decorator()
def jit_calculate_exp_util(beta: float, total_costs):
    """Exponential-utility particle weights: softmax of costs scaled by -1/beta."""
    return torch.softmax(total_costs * (-1.0 / beta), dim=-1)
@torch.jit.script
@get_torch_jit_decorator()
def jit_calculate_exp_util_from_costs(costs, gamma_seq, beta: float):
    """Discount per-step costs over the horizon and softmax into particle weights.

    Equivalent to ``jit_calculate_exp_util(beta, jit_compute_total_cost(...))``
    but fused into one function for the JIT.
    """
    discounted = torch.sum(gamma_seq * costs, dim=-1, keepdim=False) / gamma_seq[..., 0]
    return torch.softmax(discounted * (-1.0 / beta), dim=-1)
@get_torch_jit_decorator()
def jit_compute_total_cost(gamma_seq, costs):
    """Discounted cost summed over the horizon, normalized by gamma_seq[..., 0]."""
    return torch.sum(gamma_seq * costs, dim=-1, keepdim=False) / gamma_seq[..., 0]
@torch.jit.script
@get_torch_jit_decorator()
def jit_diag_a_cov_update(w, actions, mean_action):
delta_actions = actions - mean_action.unsqueeze(-3)
@@ -616,13 +646,35 @@ def jit_diag_a_cov_update(w, actions, mean_action):
return cov_update
@torch.jit.script
@get_torch_jit_decorator()
def jit_blend_cov(cov_action, cov_update, step_size_cov: float, kappa: float):
    """Exponentially blend the covariance toward ``cov_update`` plus a kappa floor."""
    blended = cov_action * (1.0 - step_size_cov) + cov_update * step_size_cov
    return blended + kappa
@torch.jit.script
@get_torch_jit_decorator()
def jit_blend_mean(mean_action, new_mean, step_size_mean: float):
    """Exponentially blend the running action mean toward ``new_mean``."""
    return mean_action * (1.0 - step_size_mean) + new_mean * step_size_mean
@get_torch_jit_decorator()
def jit_mean_cov_diag_a(
    costs,
    actions,
    gamma_seq,
    mean_action,
    cov_action,
    step_size_mean: float,
    step_size_cov: float,
    kappa: float,
    beta: float,
):
    """Fused MPPI update for the DIAG_A covariance type.

    Computes exponential-utility particle weights from discounted costs, then
    blends the weighted-mean action and diagonal covariance toward the current
    distribution, returning the new Cholesky factor as well (for a diagonal
    covariance that is simply the elementwise square root).

    Returns:
        (new_mean, new_cov, new_tril)
    """
    w = jit_calculate_exp_util_from_costs(costs, gamma_seq, beta)
    # Broadcast weights over the (horizon, action) trailing dims.
    w = w.unsqueeze(-1).unsqueeze(-1)
    # Weighted mean over the particle dimension, blended with the old mean.
    new_mean = torch.sum(w * actions, dim=-3)
    new_mean = jit_blend_mean(mean_action, new_mean, step_size_mean)
    cov_update = jit_diag_a_cov_update(w, actions, mean_action)
    new_cov = jit_blend_cov(cov_action, cov_update, step_size_cov, kappa)
    # Diagonal covariance: Cholesky factor is the elementwise sqrt.
    new_tril = torch.sqrt(new_cov)
    return new_mean, new_cov, new_tril

View File

@@ -263,7 +263,7 @@ class ParticleOptBase(Optimizer, ParticleOptConfig):
trajectory.costs = trajectory.costs.view(
self.n_problems, self.particles_per_problem, self.horizon
)
with profiler.record_function("mpc/mppi/update_distribution"):
with profiler.record_function("mppi/update_distribution"):
self._update_distribution(trajectory)
if not self.use_cuda_graph and self.store_debug:
self.debug.append(self._get_action_seq(mode=self.sample_mode).clone())

View File

@@ -18,6 +18,7 @@ import torch.autograd.profiler as profiler
# CuRobo
from curobo.types.base import TensorDeviceType
from curobo.util.torch_utils import get_torch_jit_decorator
class SquashType(Enum):
@@ -74,7 +75,7 @@ def get_stomp_cov(
return cov, scale_tril
@torch.jit.script
@get_torch_jit_decorator()
def get_stomp_cov_jit(
horizon: int,
d_action: int,
@@ -245,7 +246,7 @@ def gaussian_kl(mean0, cov0, mean1, cov1, cov_type="full"):
return term1 + mahalanobis_dist + term3
# @torch.jit.script
# @get_torch_jit_decorator()
def cost_to_go(cost_seq, gamma_seq, only_first=False):
# type: (Tensor, Tensor, bool) -> Tensor

View File

@@ -40,7 +40,7 @@ from curobo.types.base import TensorDeviceType
from curobo.types.robot import CSpaceConfig, RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.tensor_util import cat_sum
from curobo.util.tensor_util import cat_sum, cat_sum_horizon
@dataclass
@@ -104,10 +104,10 @@ class ArmCostConfig:
@dataclass
class ArmBaseConfig(RolloutConfig):
model_cfg: KinematicModelConfig
cost_cfg: ArmCostConfig
constraint_cfg: ArmCostConfig
convergence_cfg: ArmCostConfig
model_cfg: Optional[KinematicModelConfig] = None
cost_cfg: Optional[ArmCostConfig] = None
constraint_cfg: Optional[ArmCostConfig] = None
convergence_cfg: Optional[ArmCostConfig] = None
world_coll_checker: Optional[WorldCollision] = None
@staticmethod
@@ -322,7 +322,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
self.null_convergence = DistCost(self.convergence_cfg.null_space_cfg)
# set start state:
start_state = torch.randn((1, self.dynamics_model.d_state), **vars(self.tensor_args))
start_state = torch.randn(
(1, self.dynamics_model.d_state), **(self.tensor_args.as_torch_dict())
)
self._start_state = JointState(
position=start_state[:, : self.dynamics_model.d_dof],
velocity=start_state[:, : self.dynamics_model.d_dof],
@@ -366,9 +368,11 @@ class ArmBase(RolloutBase, ArmBaseConfig):
)
cost_list.append(coll_cost)
if return_list:
return cost_list
cost = cat_sum(cost_list)
if self.sum_horizon:
cost = cat_sum_horizon(cost_list)
else:
cost = cat_sum(cost_list)
return cost
def constraint_fn(

View File

@@ -10,7 +10,7 @@
#
# Standard Library
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict, List, Optional
# Third Party
import torch
@@ -29,8 +29,9 @@ from curobo.types.base import TensorDeviceType
from curobo.types.robot import RobotConfig
from curobo.types.tensor import T_BValue_float, T_BValue_int
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_error, log_info
from curobo.util.tensor_util import cat_max, cat_sum
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.tensor_util import cat_max
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .arm_base import ArmBase, ArmBaseConfig, ArmCostConfig
@@ -145,7 +146,7 @@ class ArmReacherConfig(ArmBaseConfig):
)
@torch.jit.script
@get_torch_jit_decorator()
def _compute_g_dist_jit(rot_err_norm, goal_dist):
# goal_cost = goal_cost.view(cost.shape)
# rot_err_norm = rot_err_norm.view(cost.shape)
@@ -319,7 +320,12 @@ class ArmReacher(ArmBase, ArmReacherConfig):
g_dist,
)
cost_list.append(z_vel)
cost = cat_sum(cost_list)
with profiler.record_function("cat_sum"):
if self.sum_horizon:
cost = cat_sum_horizon_reacher(cost_list)
else:
cost = cat_sum_reacher(cost_list)
return cost
def convergence_fn(
@@ -466,3 +472,15 @@ class ArmReacher(ArmBase, ArmReacherConfig):
)
for x in pose_costs
]
@get_torch_jit_decorator()
def cat_sum_reacher(tensor_list: List[torch.Tensor]):
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=0)
return cat_tensor
@get_torch_jit_decorator()
def cat_sum_horizon_reacher(tensor_list: List[torch.Tensor]):
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=(0, -1))
return cat_tensor

View File

@@ -21,6 +21,7 @@ import warp as wp
from curobo.cuda_robot_model.types import JointLimits
from curobo.types.robot import JointState
from curobo.types.tensor import T_DOF
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.warp import init_warp
# Local Folder
@@ -267,7 +268,7 @@ class BoundCost(CostBase, BoundCostConfig):
return super().update_dt(dt)
@torch.jit.script
@get_torch_jit_decorator()
def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
# c = weight * torch.sum(torch.nn.functional.relu(torch.max(lower_bounds - p, p - upper_bounds)), dim=-1)
@@ -281,7 +282,7 @@ def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
return c
@torch.jit.script
@get_torch_jit_decorator()
def forward_all_bound_cost(
p,
v,

View File

@@ -18,6 +18,7 @@ import torch
import warp as wp
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.warp import init_warp
# Local Folder
@@ -41,32 +42,32 @@ class DistCostConfig(CostConfig):
return super().__post_init__()
@torch.jit.script
@get_torch_jit_decorator()
def L2_DistCost_jit(vec_weight, disp_vec):
return torch.norm(vec_weight * disp_vec, p=2, dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def fwd_SQL2_DistCost_jit(vec_weight, disp_vec):
return torch.sum(torch.square(vec_weight * disp_vec), dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def fwd_L1_DistCost_jit(vec_weight, disp_vec):
return torch.sum(torch.abs(vec_weight * disp_vec), dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def L2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
return torch.norm(weight * vec_weight * (g_vec - c_vec), p=2, dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def fwd_SQL2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
return torch.sum(torch.square(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def fwd_L1_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
return torch.sum(torch.abs(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)

View File

@@ -19,6 +19,8 @@ import torch
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollision
from curobo.rollout.cost.cost_base import CostBase, CostConfig
from curobo.rollout.dynamics_model.integration_utils import interpolate_kernel, sum_matrix
from curobo.util.logger import log_info
from curobo.util.torch_utils import get_torch_jit_decorator
@dataclass
@@ -48,7 +50,11 @@ class PrimitiveCollisionCostConfig(CostConfig):
#: post optimization interpolation to not hit any obstacles.
activation_distance: Union[torch.Tensor, float] = 0.0
#: Setting this flag to true will sum the distance colliding obstacles.
sum_collisions: bool = True
#: Setting this flag to true will sum the distance across spheres of the robot.
#: Setting to False will only take the max distance
sum_distance: bool = True
def __post_init__(self):
@@ -103,6 +109,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
self._collision_query_buffer.update_buffer_shape(
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
)
if not self.sum_distance:
log_info("sum_distance=False will be slower than sum_distance=True")
self.return_loss = True
dist = self.sweep_check_fn(
robot_spheres_in,
self._collision_query_buffer,
@@ -115,9 +124,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
return_loss=self.return_loss,
)
if self.classify:
cost = weight_collision(dist, self.weight, self.sum_distance)
cost = weight_collision(dist, self.sum_distance)
else:
cost = weight_distance(dist, self.weight, self.sum_distance)
cost = weight_distance(dist, self.sum_distance)
return cost
def sweep_fn(self, robot_spheres_in, env_query_idx: Optional[torch.Tensor] = None):
@@ -140,6 +149,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
self._collision_query_buffer.update_buffer_shape(
sampled_spheres.shape, self.tensor_args, self.world_coll_checker.collision_types
)
if not self.sum_distance:
log_info("sum_distance=False will be slower than sum_distance=True")
self.return_loss = True
dist = self.coll_check_fn(
sampled_spheres.contiguous(),
self._collision_query_buffer,
@@ -151,9 +163,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
dist = dist.view(batch_size, new_horizon, n_spheres)
if self.classify:
cost = weight_sweep_collision(self.int_sum_mat, dist, self.weight, self.sum_distance)
cost = weight_sweep_collision(self.int_sum_mat, dist, self.sum_distance)
else:
cost = weight_sweep_distance(self.int_sum_mat, dist, self.weight, self.sum_distance)
cost = weight_sweep_distance(self.int_sum_mat, dist, self.sum_distance)
return cost
@@ -161,6 +173,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
self._collision_query_buffer.update_buffer_shape(
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
)
if not self.sum_distance:
log_info("sum_distance=False will be slower than sum_distance=True")
self.return_loss = True
dist = self.coll_check_fn(
robot_spheres_in,
self._collision_query_buffer,
@@ -168,12 +183,13 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
env_query_idx=env_query_idx,
activation_distance=self.activation_distance,
return_loss=self.return_loss,
sum_collisions=self.sum_collisions,
)
if self.classify:
cost = weight_collision(dist, self.weight, self.sum_distance)
cost = weight_collision(dist, self.sum_distance)
else:
cost = weight_distance(dist, self.weight, self.sum_distance)
cost = weight_distance(dist, self.sum_distance)
return cost
def update_dt(self, dt: Union[float, torch.Tensor]):
@@ -184,31 +200,43 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
return self._collision_query_buffer.get_gradient_buffer()
@torch.jit.script
def weight_sweep_distance(int_mat, dist, weight, sum_cost: bool):
dist = torch.sum(dist, dim=-1)
@get_torch_jit_decorator()
def weight_sweep_distance(int_mat, dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
dist = dist @ int_mat
return dist
@torch.jit.script
def weight_sweep_collision(int_mat, dist, weight, sum_cost: bool):
dist = torch.sum(dist, dim=-1)
@get_torch_jit_decorator()
def weight_sweep_collision(int_mat, dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
dist = torch.where(dist > 0, dist + 1.0, dist)
dist = dist @ int_mat
return dist
@torch.jit.script
def weight_distance(dist, weight, sum_cost: bool):
@get_torch_jit_decorator()
def weight_distance(dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
return dist
@torch.jit.script
def weight_collision(dist, weight, sum_cost: bool):
@get_torch_jit_decorator()
def weight_collision(dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
dist = torch.where(dist > 0, dist + 1.0, dist)
return dist

View File

@@ -17,6 +17,7 @@ import torch
# CuRobo
from curobo.rollout.dynamics_model.kinematic_model import TimeTrajConfig
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .cost_base import CostBase, CostConfig
@@ -72,7 +73,7 @@ class StopCost(CostBase, StopCostConfig):
return cost
@torch.jit.script
@get_torch_jit_decorator()
def velocity_cost(vels, weight, max_vel):
vel_abs = torch.abs(vels)
vel_abs = torch.nn.functional.relu(vel_abs - max_vel[: vels.shape[1]])

View File

@@ -13,11 +13,14 @@
# Third Party
import torch
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .cost_base import CostBase, CostConfig
@torch.jit.script
@get_torch_jit_decorator()
def st_cost(ee_pos_batch, vec_weight, weight):
ee_plus_one = torch.roll(ee_pos_batch, 1, dims=1)

View File

@@ -11,11 +11,14 @@
# Third Party
import torch
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .cost_base import CostBase
@torch.jit.script
@get_torch_jit_decorator()
def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
# return weight * torch.square(torch.linalg.norm(cost, dim=-1, ord=1))
# return weight * torch.sum(torch.square(cost), dim=-1)
@@ -24,7 +27,7 @@ def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
return torch.sum(torch.square(cost) * weight, dim=-1)
@torch.jit.script
@get_torch_jit_decorator()
def run_squared_sum(
cost: torch.Tensor, weight: torch.Tensor, run_weight: torch.Tensor
) -> torch.Tensor:
@@ -35,13 +38,13 @@ def run_squared_sum(
# return torch.sum(torch.square(cost), dim=-1) * weight * run_weight
@torch.jit.script
@get_torch_jit_decorator()
def backward_squared_sum(cost_vec, w):
return 2.0 * w * cost_vec # * g_out.unsqueeze(-1)
# return w * g_out.unsqueeze(-1)
@torch.jit.script
@get_torch_jit_decorator()
def backward_run_squared_sum(cost_vec, w, r_w):
return 2.0 * w * r_w.unsqueeze(-1) * cost_vec # * g_out.unsqueeze(-1)
# return w * r_w.unsqueeze(-1) * cost_vec * g_out.unsqueeze(-1)

View File

@@ -25,6 +25,7 @@ from curobo.curobolib.tensor_step import (
)
from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
from curobo.util.torch_utils import get_torch_jit_decorator
def build_clique_matrix(horizon, dt, device="cpu", dtype=torch.float32):
@@ -154,7 +155,7 @@ def build_start_state_mask(horizon, tensor_args: TensorDeviceType):
return mask, n_mask
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_jerk(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
@@ -176,7 +177,7 @@ def tensor_step_jerk(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_m
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def euler_integrate(q_0, u, diag_dt, integrate_matrix):
# q_new = q_0 + torch.matmul(integrate_matrix, torch.matmul(diag_dt, u))
q_new = q_0 + torch.matmul(integrate_matrix, u)
@@ -184,7 +185,7 @@ def euler_integrate(q_0, u, diag_dt, integrate_matrix):
return q_new
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_acc(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
# This is batch,n_dof
@@ -207,7 +208,7 @@ def tensor_step_acc(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_ma
return state_seq
@torch.jit.script
@get_torch_jit_decorator()
def jit_tensor_step_pos_clique_contiguous(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
state_position = (start_position.unsqueeze(1).transpose(1, 2) @ mask.transpose(0, 1)) + (
pos_act.transpose(1, 2) @ n_mask.transpose(0, 1)
@@ -222,7 +223,7 @@ def jit_tensor_step_pos_clique_contiguous(pos_act, start_position, mask, n_mask,
return state_position, state_vel, state_acc, state_jerk
@torch.jit.script
@get_torch_jit_decorator()
def jit_tensor_step_pos_clique(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
state_position = mask @ start_position.unsqueeze(1) + n_mask @ pos_act
state_vel = fd_1 @ state_position
@@ -231,7 +232,7 @@ def jit_tensor_step_pos_clique(pos_act, start_position, mask, n_mask, fd_1, fd_2
return state_position, state_vel, state_acc, state_jerk
@torch.jit.script
@get_torch_jit_decorator()
def jit_backward_pos_clique(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
p_grad = (
grad_p
@@ -247,7 +248,7 @@ def jit_backward_pos_clique(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2,
return u_grad
@torch.jit.script
@get_torch_jit_decorator()
def jit_backward_pos_clique_contiguous(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
p_grad = grad_p + (
grad_j.transpose(-1, -2) @ fd_3
@@ -532,7 +533,7 @@ class CliqueTensorStepIdxCentralDifferenceKernel(torch.autograd.Function):
start_position,
start_velocity,
start_acceleration,
start_idx,
start_idx.contiguous(),
traj_dt,
out_position.shape[0],
out_position.shape[1],
@@ -750,7 +751,7 @@ class AccelerationTensorStepIdxKernel(torch.autograd.Function):
return u_grad, None, None, None, None, None, None, None, None, None, None
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_pos_clique(
state: JointState,
act: torch.Tensor,
@@ -786,7 +787,7 @@ def step_acc_semi_euler(state, act, diag_dt, n_dofs, integrate_matrix):
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_acc_semi_euler(
state, act, state_seq, diag_dt, integrate_matrix, integrate_matrix_pos
):
@@ -806,7 +807,7 @@ def tensor_step_acc_semi_euler(
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_vel(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Tensor) -> Tensor
@@ -830,7 +831,7 @@ def tensor_step_vel(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_ma
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_pos(state, act, state_seq, fd_matrix):
# This is batch,n_dof
state_seq.position[:, 0, :] = state.position
@@ -850,7 +851,7 @@ def tensor_step_pos(state, act, state_seq, fd_matrix):
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_pos_ik(act, state_seq):
state_seq.position = act
return state_seq
@@ -869,7 +870,7 @@ def tensor_linspace(start_tensor, end_tensor, steps=10):
def sum_matrix(h, int_steps, tensor_args):
sum_mat = torch.zeros(((h - 1) * int_steps, h), **vars(tensor_args))
sum_mat = torch.zeros(((h - 1) * int_steps, h), **(tensor_args.as_torch_dict()))
for i in range(h - 1):
sum_mat[i * int_steps : i * int_steps + int_steps, i] = 1.0
# hack:

View File

@@ -19,6 +19,7 @@ import torch
# CuRobo
from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .integration_utils import (
@@ -544,7 +545,7 @@ class TensorStepPositionCliqueKernel(TensorStepBase):
return new_signal
@torch.jit.script
@get_torch_jit_decorator(force_jit=True)
def filter_signal_jit(signal, kernel):
b, h, dof = signal.shape

View File

@@ -36,6 +36,7 @@ from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_info
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.tensor_util import copy_tensor
from curobo.util.torch_utils import get_torch_jit_decorator
@dataclass
@@ -298,9 +299,9 @@ class Goal(Sequence):
if self.goal_pose is not None:
self.goal_pose = self.goal_pose.to(tensor_args)
if self.goal_state is not None:
self.goal_state = self.goal_state.to(**vars(tensor_args))
self.goal_state = self.goal_state.to(**(tensor_args.as_torch_dict()))
if self.current_state is not None:
self.current_state = self.current_state.to(**vars(tensor_args))
self.current_state = self.current_state.to(**(tensor_args.as_torch_dict()))
return self
def copy_(self, goal: Goal, update_idx_buffers: bool = True):
@@ -350,6 +351,7 @@ class Goal(Sequence):
if ref_buffer is not None:
ref_buffer = ref_buffer.copy_(buffer)
else:
log_info("breaking reference")
ref_buffer = buffer.clone()
return ref_buffer
@@ -414,6 +416,7 @@ class Goal(Sequence):
@dataclass
class RolloutConfig:
tensor_args: TensorDeviceType
sum_horizon: bool = False
class RolloutBase:
@@ -578,7 +581,7 @@ class RolloutBase:
return q_js
@torch.jit.script
@get_torch_jit_decorator()
def tensor_repeat_seeds(tensor, num_seeds: int):
a = (
tensor.view(tensor.shape[0], 1, tensor.shape[-1])

View File

@@ -19,7 +19,10 @@ import torch
@dataclass(frozen=True)
class TensorDeviceType:
device: torch.device = torch.device("cuda", 0)
dtype: torch.float32 = torch.float32
dtype: torch.dtype = torch.float32
collision_geometry_dtype: torch.dtype = torch.float32
collision_gradient_dtype: torch.dtype = torch.float32
collision_distance_dtype: torch.dtype = torch.float32
@staticmethod
def from_basic(device: str, dev_id: int):
@@ -36,3 +39,6 @@ class TensorDeviceType:
def cpu(self):
return TensorDeviceType(device=torch.device("cpu"), dtype=self.dtype)
def as_torch_dict(self):
return {"device": self.device, "dtype": self.dtype}

View File

@@ -39,7 +39,7 @@ class CameraObservation:
resolution: Optional[List[int]] = None
pose: Optional[Pose] = None
intrinsics: Optional[torch.Tensor] = None
timestamp: float = 0.0
timestamp: Optional[torch.Tensor] = None
def filter_depth(self, distance: float = 0.01):
self.depth_image = torch.where(self.depth_image < distance, 0, self.depth_image)
@@ -62,6 +62,8 @@ class CameraObservation:
self.projection_rays.copy_(new_data.projection_rays)
if self.pose is not None:
self.pose.copy_(new_data.pose)
if self.timestamp is not None:
self.timestamp.copy_(new_data.timestamp)
self.resolution = new_data.resolution
@record_function("camera/clone")
@@ -73,13 +75,41 @@ class CameraObservation:
intrinsics=self.intrinsics.clone() if self.intrinsics is not None else None,
resolution=self.resolution,
pose=self.pose.clone() if self.pose is not None else None,
timestamp=self.timestamp.clone() if self.timestamp is not None else None,
image_segmentation=(
self.image_segmentation.clone() if self.image_segmentation is not None else None
),
projection_matrix=(
self.projection_matrix.clone() if self.projection_matrix is not None else None
),
projection_rays=(
self.projection_rays.clone() if self.projection_rays is not None else None
),
name=self.name,
)
def to(self, device: torch.device):
if self.rgb_image is not None:
self.rgb_image = self.rgb_image.to(device=device)
if self.depth_image is not None:
self.depth_image = self.depth_image.to(device=device)
self.rgb_image = self.rgb_image.to(device=device) if self.rgb_image is not None else None
self.depth_image = (
self.depth_image.to(device=device) if self.depth_image is not None else None
)
self.image_segmentation = (
self.image_segmentation.to(device=device)
if self.image_segmentation is not None
else None
)
self.projection_matrix = (
self.projection_matrix.to(device=device) if self.projection_matrix is not None else None
)
self.projection_rays = (
self.projection_rays.to(device=device) if self.projection_rays is not None else None
)
self.intrinsics = self.intrinsics.to(device=device) if self.intrinsics is not None else None
self.timestamp = self.timestamp.to(device=device) if self.timestamp is not None else None
self.pose = self.pose.to(device=device) if self.pose is not None else None
return self
def get_pointcloud(self):
@@ -114,3 +144,56 @@ class CameraObservation:
self.projection_rays = project_rays
self.projection_rays.copy_(project_rays)
def stack(self, new_observation: CameraObservation, dim: int = 0):
rgb_image = (
torch.stack((self.rgb_image, new_observation.rgb_image), dim=dim)
if self.rgb_image is not None
else None
)
depth_image = (
torch.stack((self.depth_image, new_observation.depth_image), dim=dim)
if self.depth_image is not None
else None
)
image_segmentation = (
torch.stack((self.image_segmentation, new_observation.image_segmentation), dim=dim)
if self.image_segmentation is not None
else None
)
projection_matrix = (
torch.stack((self.projection_matrix, new_observation.projection_matrix), dim=dim)
if self.projection_matrix is not None
else None
)
projection_rays = (
torch.stack((self.projection_rays, new_observation.projection_rays), dim=dim)
if self.projection_rays is not None
else None
)
resolution = self.resolution
pose = self.pose.stack(new_observation.pose) if self.pose is not None else None
intrinsics = (
torch.stack((self.intrinsics, new_observation.intrinsics), dim=dim)
if self.intrinsics is not None
else None
)
timestamp = (
torch.stack((self.timestamp, new_observation.timestamp), dim=dim)
if self.timestamp is not None
else None
)
return CameraObservation(
name=self.name,
rgb_image=rgb_image,
depth_image=depth_image,
image_segmentation=image_segmentation,
projection_matrix=projection_matrix,
projection_rays=projection_rays,
resolution=resolution,
pose=pose,
intrinsics=intrinsics,
timestamp=timestamp,
)

View File

@@ -12,7 +12,7 @@ from __future__ import annotations
# Standard Library
from dataclasses import dataclass
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union
# Third Party
import numpy as np
@@ -35,6 +35,7 @@ from curobo.types.base import TensorDeviceType
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.tensor_util import clone_if_not_none, copy_tensor
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .tensor import T_BPosition, T_BQuaternion, T_BRotation
@@ -263,8 +264,17 @@ class Pose(Sequence):
# rotation=clone_if_not_none(self.rotation),
)
def to(self, tensor_args: TensorDeviceType):
t_type = vars(tensor_args)
def to(
self,
tensor_args: Optional[TensorDeviceType] = None,
device: Optional[torch.device] = None,
):
if tensor_args is None and device is None:
log_error("Pose.to() requires tensor_args or device")
if tensor_args is not None:
t_type = vars(tensor_args.as_torch_dict())
else:
t_type = {"device": device}
if self.position is not None:
self.position = self.position.to(**t_type)
if self.quaternion is not None:
@@ -338,7 +348,7 @@ class Pose(Sequence):
return p_distance, quat_distance
def angular_distance(self, other_pose: Pose, use_phi3: bool = False):
"""This function computes the angular distance \phi_3.
"""This function computes the angular distance phi_3.
See Huynh, Du Q. "Metrics for 3D rotations: Comparison and analysis." Journal of Mathematical
Imaging and Vision 35 (2009): 155-164.
@@ -461,9 +471,9 @@ def quat_multiply(q1, q2, q_res):
return q_res
@torch.jit.script
@get_torch_jit_decorator()
def angular_distance_phi3(goal_quat, current_quat):
"""This function computes the angular distance \phi_3.
"""This function computes the angular distance phi_3.
See Huynh, Du Q. "Metrics for 3D rotations: Comparison and analysis." Journal of Mathematical
Imaging and Vision 35 (2009): 155-164.
@@ -524,7 +534,7 @@ class OrientationError(Function):
return None, grad_mul, None
@torch.jit.script
@get_torch_jit_decorator()
def normalize_quaternion(in_quaternion: torch.Tensor) -> torch.Tensor:
k = torch.sign(in_quaternion[..., 0:1])
# NOTE: torch sign returns 0 as sign value when value is 0.0

View File

@@ -30,6 +30,7 @@ from curobo.util.tensor_util import (
fd_tensor,
tensor_repeat_seeds,
)
from curobo.util.torch_utils import get_torch_jit_decorator
@dataclass
@@ -211,10 +212,10 @@ class JointState(State):
j = None
v = a = None
max_idx = 0
if isinstance(idx, List):
idx = torch.as_tensor(idx, device=self.position.device, dtype=torch.long)
if isinstance(idx, int):
max_idx = idx
elif isinstance(idx, List):
max_idx = max(idx)
elif isinstance(idx, torch.Tensor):
max_idx = torch.max(idx)
if max_idx >= self.position.shape[0]:
@@ -223,31 +224,19 @@ class JointState(State):
+ " index out of range, current state is of length "
+ str(self.position.shape)
)
p = self.position[idx]
if self.velocity is not None:
if max_idx >= self.velocity.shape[0]:
raise ValueError(
str(max_idx)
+ " index out of range, current velocity is of length "
+ str(self.velocity.shape)
)
v = self.velocity[idx]
if self.acceleration is not None:
if max_idx >= self.acceleration.shape[0]:
raise ValueError(
str(max_idx)
+ " index out of range, current acceleration is of length "
+ str(self.acceleration.shape)
)
a = self.acceleration[idx]
if self.jerk is not None:
if max_idx >= self.jerk.shape[0]:
raise ValueError(
str(max_idx)
+ " index out of range, current jerk is of length "
+ str(self.jerk.shape)
)
j = self.jerk[idx]
if isinstance(idx, int):
p, v, a, j = jit_get_index_int(
self.position, self.velocity, self.acceleration, self.jerk, idx
)
elif isinstance(idx, torch.Tensor):
p, v, a, j = jit_get_index(
self.position, self.velocity, self.acceleration, self.jerk, idx
)
else:
p, v, a, j = fn_get_index(
self.position, self.velocity, self.acceleration, self.jerk, idx
)
return JointState(p, v, a, joint_names=self.joint_names, jerk=j)
def __len__(self):
@@ -514,6 +503,88 @@ class JointState(State):
jerk = self.jerk * (dt**3)
return JointState(self.position, vel, acc, self.joint_names, jerk, self.tensor_args)
def scale_by_dt(self, dt: torch.Tensor, new_dt: torch.Tensor):
vel, acc, jerk = jit_js_scale(self.velocity, self.acceleration, self.jerk, dt, new_dt)
return JointState(self.position, vel, acc, self.joint_names, jerk, self.tensor_args)
@property
def shape(self):
return self.position.shape
@get_torch_jit_decorator()
def jit_js_scale(
vel: Union[None, torch.Tensor],
acc: Union[None, torch.Tensor],
jerk: Union[None, torch.Tensor],
dt: torch.Tensor,
new_dt: torch.Tensor,
):
scale_dt = dt / new_dt
if vel is not None:
vel = vel * scale_dt
if acc is not None:
acc = acc * scale_dt * scale_dt
if jerk is not None:
jerk = jerk * scale_dt * scale_dt * scale_dt
return vel, acc, jerk
@get_torch_jit_decorator()
def jit_get_index(
position: torch.Tensor,
velocity: Union[torch.Tensor, None],
acc: Union[torch.Tensor, None],
jerk: Union[torch.Tensor, None],
idx: torch.Tensor,
):
position = position[idx]
if velocity is not None:
velocity = velocity[idx]
if acc is not None:
acc = acc[idx]
if jerk is not None:
jerk = jerk[idx]
return position, velocity, acc, jerk
def fn_get_index(
position: torch.Tensor,
velocity: Union[torch.Tensor, None],
acc: Union[torch.Tensor, None],
jerk: Union[torch.Tensor, None],
idx: torch.Tensor,
):
position = position[idx]
if velocity is not None:
velocity = velocity[idx]
if acc is not None:
acc = acc[idx]
if jerk is not None:
jerk = jerk[idx]
return position, velocity, acc, jerk
@get_torch_jit_decorator()
def jit_get_index_int(
position: torch.Tensor,
velocity: Union[torch.Tensor, None],
acc: Union[torch.Tensor, None],
jerk: Union[torch.Tensor, None],
idx: int,
):
position = position[idx]
if velocity is not None:
velocity = velocity[idx]
if acc is not None:
acc = acc[idx]
if jerk is not None:
jerk = jerk[idx]
return position, velocity, acc, jerk

View File

@@ -9,54 +9,24 @@
# its affiliates is strictly prohibited.
#
""" This module contains aliases for structured Tensors, improving readability."""
# Third Party
import torch
# CuRobo
from curobo.util.logger import log_warn
try:
# Third Party
from torchtyping import TensorType
except ImportError:
log_warn("torchtyping could not be imported, falling back to basic types")
TensorType = None
# Third Party
import torch
b_dof = ["batch", "dof"]
b_value = ["batch", "value"]
bh_value = ["batch", "horizon", "value"]
bh_dof = ["batch", "horizon", "dof"]
h_dof = ["horizon", "dof"]
T_DOF = torch.Tensor #: Tensor of shape [degrees of freedom]
T_BDOF = torch.Tensor #: Tensor of shape [batch, degrees of freedom]
T_BHDOF_float = torch.Tensor #: Tensor of shape [batch, horizon, degrees of freedom]
T_HDOF_float = torch.Tensor #: Tensor of shape [horizon, degrees of freedom]
if TensorType is not None:
T_DOF = TensorType[tuple(["dof"] + [float])]
T_BDOF = TensorType[tuple(b_dof + [float])]
T_BValue_float = TensorType[tuple(b_value + [float])]
T_BHValue_float = TensorType[tuple(bh_value + [float])]
T_BValue_bool = TensorType[tuple(b_value + [bool])]
T_BValue_int = TensorType[tuple(b_value + [int])]
T_BValue_float = torch.Tensor #: Float Tensor of shape [batch, 1].
T_BHValue_float = torch.Tensor #: Float Tensor of shape [batch, horizon, 1].
T_BValue_bool = torch.Tensor #: Bool Tensor of shape [batch, horizon, 1].
T_BValue_int = torch.Tensor #: Int Tensor of shape [batch, horizon, 1].
T_BPosition = TensorType["batch", "xyz":3, float]
T_BQuaternion = TensorType["batch", "wxyz":4, float]
T_BRotation = TensorType["batch", 3, 3, float]
T_Position = TensorType["xyz":3, float]
T_Quaternion = TensorType["wxyz":4, float]
T_Rotation = TensorType[3, 3, float]
T_BHDOF_float = TensorType[tuple(bh_dof + [float])]
T_HDOF_float = TensorType[tuple(h_dof + [float])]
else:
T_DOF = torch.Tensor
T_BDOF = torch.Tensor
T_BValue_float = torch.Tensor
T_BHValue_float = torch.Tensor
T_BValue_bool = torch.Tensor
T_BValue_int = torch.Tensor
T_BPosition = torch.Tensor
T_BQuaternion = torch.Tensor
T_BRotation = torch.Tensor
T_Position = torch.Tensor
T_Quaternion = torch.Tensor
T_Rotation = torch.Tensor
T_BHDOF_float = torch.Tensor
T_HDOF_float = torch.Tensor
T_BPosition = torch.Tensor #: Tensor of shape [batch, 3].
T_BQuaternion = torch.Tensor #: Tensor of shape [batch, 4].
T_BRotation = torch.Tensor #: Tensor of shape [batch, 3,3].

View File

@@ -25,6 +25,7 @@ from torch.distributions.multivariate_normal import MultivariateNormal
# CuRobo
from curobo.types.base import TensorDeviceType
from curobo.util.logger import log_error, log_warn
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from ..opt.particle.particle_opt_utils import get_stomp_cov
@@ -511,7 +512,7 @@ class HaltonGenerator:
return gaussian_halton_samples
@torch.jit.script
@get_torch_jit_decorator()
def gaussian_transform(
uniform_samples: torch.Tensor, proj_mat: torch.Tensor, i_mat: torch.Tensor, std_dev: float
):

View File

@@ -14,6 +14,9 @@ from typing import List
# Third Party
import torch
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
def check_tensor_shapes(new_tensor: torch.Tensor, mem_tensor: torch.Tensor):
if not isinstance(mem_tensor, torch.Tensor):
@@ -65,13 +68,19 @@ def clone_if_not_none(x):
return None
@torch.jit.script
@get_torch_jit_decorator()
def cat_sum(tensor_list: List[torch.Tensor]):
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=0)
return cat_tensor
@torch.jit.script
@get_torch_jit_decorator()
def cat_sum_horizon(tensor_list: List[torch.Tensor]):
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=(0, -1))
return cat_tensor
@get_torch_jit_decorator()
def cat_max(tensor_list: List[torch.Tensor]):
cat_tensor = torch.max(torch.stack(tensor_list, dim=0), dim=0)[0]
return cat_tensor
@@ -85,7 +94,7 @@ def tensor_repeat_seeds(tensor, num_seeds):
)
@torch.jit.script
@get_torch_jit_decorator()
def fd_tensor(p: torch.Tensor, dt: torch.Tensor):
out = ((torch.roll(p, -1, -2) - p) * (1 / dt).unsqueeze(-1))[..., :-1, :]
return out

View File

@@ -8,12 +8,15 @@
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
import os
# Third Party
import torch
from packaging import version
# CuRobo
from curobo.util.logger import log_warn
from curobo.util.logger import log_info, log_warn
def find_first_idx(array, value, EQUAL=False):
@@ -31,13 +34,119 @@ def find_last_idx(array, value):
def is_cuda_graph_available():
if version.parse(torch.__version__) < version.parse("1.10"):
log_warn("WARNING: Disabling CUDA Graph as pytorch < 1.10")
log_warn("Disabling CUDA Graph as pytorch < 1.10")
return False
return True
def is_torch_compile_available():
force_compile = os.environ.get("CUROBO_TORCH_COMPILE_FORCE")
if force_compile is not None and bool(int(force_compile)):
return True
if version.parse(torch.__version__) < version.parse("2.0"):
log_warn("WARNING: Disabling compile as pytorch < 2.0")
log_info("Disabling torch.compile as pytorch < 2.0")
return False
env_variable = os.environ.get("CUROBO_TORCH_COMPILE_DISABLE")
if env_variable is None:
log_info("Environment variable for CUROBO_TORCH_COMPILE is not set, Disabling.")
return False
if bool(int(env_variable)):
log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Disable")
return False
log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Enable")
try:
torch.compile
except:
log_info("Could not find torch.compile, disabling Torch Compile.")
return False
try:
torch._dynamo
except:
log_info("Could not find torch._dynamo, disabling Torch Compile.")
return False
try:
# Third Party
import triton
except:
log_info("Could not find triton, disabling Torch Compile.")
return False
return True
def get_torch_compile_options() -> dict:
options = {}
if is_torch_compile_available():
# Third Party
from torch._inductor import config
torch._dynamo.config.suppress_errors = True
use_options = {
"max_autotune": True,
"use_mixed_mm": True,
"conv_1x1_as_mm": True,
"coordinate_descent_tuning": True,
"epilogue_fusion": False,
"coordinate_descent_check_all_directions": True,
"force_fuse_int_mm_with_mul": True,
"triton.cudagraphs": False,
"aggressive_fusion": True,
"split_reductions": False,
"worker_start_method": "spawn",
}
for k in use_options.keys():
if hasattr(config, k):
options[k] = use_options[k]
else:
log_info("Not found in torch.compile: " + k)
return options
def disable_torch_compile_global():
if is_torch_compile_available():
torch._dynamo.config.disable = True
return True
return False
def set_torch_compile_global_options():
if is_torch_compile_available():
# Third Party
from torch._inductor import config
torch._dynamo.config.suppress_errors = True
if hasattr(config, "conv_1x1_as_mm"):
torch._inductor.config.conv_1x1_as_mm = True
if hasattr(config, "coordinate_descent_tuning"):
torch._inductor.config.coordinate_descent_tuning = True
if hasattr(config, "epilogue_fusion"):
torch._inductor.config.epilogue_fusion = False
if hasattr(config, "coordinate_descent_check_all_directions"):
torch._inductor.config.coordinate_descent_check_all_directions = True
if hasattr(config, "force_fuse_int_mm_with_mul"):
torch._inductor.config.force_fuse_int_mm_with_mul = True
if hasattr(config, "use_mixed_mm"):
torch._inductor.config.use_mixed_mm = True
return True
return False
def get_torch_jit_decorator(
force_jit: bool = False, dynamic: bool = True, only_valid_for_compile: bool = False
):
if not force_jit and is_torch_compile_available():
return torch.compile(options=get_torch_compile_options(), dynamic=dynamic)
elif not only_valid_for_compile:
return torch.jit.script
else:
return empty_decorator
def empty_decorator(function):
return function

View File

@@ -25,6 +25,7 @@ from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.sample_lib import bspline
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.warp_interpolation import get_cuda_linear_interpolation
@@ -114,7 +115,7 @@ def get_spline_interpolated_trajectory(raw_traj, des_horizon, degree=5):
for i in range(cpu_traj.shape[-1]):
retimed_traj[:, i] = bspline(cpu_traj[:, i], n=des_horizon, degree=degree)
retimed_traj = retimed_traj.to(**vars(tensor_args))
retimed_traj = retimed_traj.to(**(tensor_args.as_torch_dict()))
return retimed_traj
@@ -385,7 +386,7 @@ def get_interpolated_trajectory(
kind=kind,
last_step=des_horizon,
)
retimed_traj = retimed_traj.to(**vars(tensor_args))
retimed_traj = retimed_traj.to(**(tensor_args.as_torch_dict()))
out_traj_state.position[b, :interpolation_steps, :] = retimed_traj
out_traj_state.position[b, interpolation_steps:, :] = retimed_traj[
interpolation_steps - 1 : interpolation_steps, :
@@ -438,7 +439,39 @@ def linear_smooth(
return y_new
@torch.jit.script
@get_torch_jit_decorator()
def calculate_dt_fixed(
vel: torch.Tensor,
acc: torch.Tensor,
jerk: torch.Tensor,
max_vel: torch.Tensor,
max_acc: torch.Tensor,
max_jerk: torch.Tensor,
raw_dt: torch.Tensor,
interpolation_dt: float,
):
# compute scaled dt:
max_v_arr = torch.max(torch.abs(vel), dim=-2)[0] # output is batch, dof
max_acc_arr = torch.max(torch.abs(acc), dim=-2)[0]
max_jerk_arr = torch.max(torch.abs(jerk), dim=-2)[0]
vel_scale_dt = (max_v_arr) / (max_vel.view(1, max_v_arr.shape[-1])) # batch,dof
acc_scale_dt = max_acc_arr / (max_acc.view(1, max_acc_arr.shape[-1]))
jerk_scale_dt = max_jerk_arr / (max_jerk.view(1, max_jerk_arr.shape[-1]))
dt_score_vel = raw_dt * torch.max(vel_scale_dt, dim=-1)[0] # batch, 1
dt_score_acc = raw_dt * torch.sqrt((torch.max(acc_scale_dt, dim=-1)[0]))
dt_score_jerk = raw_dt * torch.pow((torch.max(jerk_scale_dt, dim=-1)[0]), 1 / 3)
dt_score = torch.maximum(dt_score_vel, dt_score_acc)
dt_score = torch.maximum(dt_score, dt_score_jerk)
dt_score = torch.clamp(dt_score, interpolation_dt, raw_dt)
# NOTE: this dt score is not dt, rather a scaling to convert velocity, acc, jerk that was
# computed with raw_dt to a new dt
return dt_score
@get_torch_jit_decorator(force_jit=True)
def calculate_dt(
vel: torch.Tensor,
acc: torch.Tensor,
@@ -470,7 +503,7 @@ def calculate_dt(
return dt_score
@torch.jit.script
@get_torch_jit_decorator(force_jit=True)
def calculate_dt_no_clamp(
vel: torch.Tensor,
acc: torch.Tensor,
@@ -497,7 +530,7 @@ def calculate_dt_no_clamp(
return dt_score
@torch.jit.script
@get_torch_jit_decorator()
def calculate_tsteps(
vel: torch.Tensor,
acc: torch.Tensor,
@@ -506,13 +539,15 @@ def calculate_tsteps(
max_vel: torch.Tensor,
max_acc: torch.Tensor,
max_jerk: torch.Tensor,
raw_dt: float,
raw_dt: torch.Tensor,
min_dt: float,
horizon: int,
optimize_dt: bool = True,
):
# compute scaled dt:
opt_dt = calculate_dt(vel, acc, jerk, max_vel, max_acc, max_jerk, raw_dt, interpolation_dt)
opt_dt = calculate_dt_fixed(
vel, acc, jerk, max_vel, max_acc, max_jerk, raw_dt, interpolation_dt
)
if not optimize_dt:
opt_dt[:] = raw_dt
traj_steps = (torch.ceil((horizon - 1) * ((opt_dt) / interpolation_dt))).to(dtype=torch.int32)

View File

@@ -11,6 +11,7 @@
# Standard Library
import os
import shutil
import sys
from typing import Dict, List
# Third Party

View File

@@ -19,13 +19,16 @@ from torch.profiler import record_function
# CuRobo
from curobo.cuda_robot_model.cuda_robot_model import CudaRobotModel
from curobo.geom.cv import get_projection_rays, project_depth_using_rays
from curobo.geom.types import PointCloud
from curobo.types.base import TensorDeviceType
from curobo.types.camera import CameraObservation
from curobo.types.math import Pose
from curobo.types.robot import RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_error
from curobo.util.torch_utils import (
get_torch_compile_options,
get_torch_jit_decorator,
is_torch_compile_available,
)
from curobo.util_file import get_robot_configs_path, join_path, load_yaml
from curobo.wrap.model.robot_world import RobotWorld, RobotWorldConfig
@@ -36,6 +39,7 @@ class RobotSegmenter:
robot_world: RobotWorld,
distance_threshold: float = 0.05,
use_cuda_graph: bool = True,
ops_dtype: torch.dtype = torch.float32,
):
self._robot_world = robot_world
self._projection_rays = None
@@ -48,11 +52,12 @@ class RobotSegmenter:
self._use_cuda_graph = use_cuda_graph
self.tensor_args = robot_world.tensor_args
self.distance_threshold = distance_threshold
self._ops_dtype = ops_dtype
@staticmethod
def from_robot_file(
robot_file: Union[str, Dict],
collision_sphere_buffer: Optional[float],
collision_sphere_buffer: Optional[float] = None,
distance_threshold: float = 0.05,
use_cuda_graph: bool = True,
tensor_args: TensorDeviceType = TensorDeviceType(),
@@ -78,7 +83,7 @@ class RobotSegmenter:
def get_pointcloud_from_depth(self, camera_obs: CameraObservation):
if self._projection_rays is None:
self.update_camera_projection(camera_obs)
depth_image = camera_obs.depth_image
depth_image = camera_obs.depth_image.to(dtype=self._ops_dtype)
if len(depth_image.shape) == 2:
depth_image = depth_image.unsqueeze(0)
points = project_depth_using_rays(depth_image, self._projection_rays)
@@ -91,7 +96,7 @@ class RobotSegmenter:
intrinsics = intrinsics.unsqueeze(0)
project_rays = get_projection_rays(
camera_obs.depth_image.shape[-2], camera_obs.depth_image.shape[-1], intrinsics
)
).to(dtype=self._ops_dtype)
if self._projection_rays is None:
self._projection_rays = project_rays
@@ -157,8 +162,12 @@ class RobotSegmenter:
def _mask_op(self, camera_obs, q):
if len(q.shape) == 1:
q = q.unsqueeze(0)
robot_spheres = self._robot_world.get_kinematics(q).link_spheres_tensor
points = self.get_pointcloud_from_depth(camera_obs)
camera_to_robot = camera_obs.pose
points = points.to(dtype=torch.float32)
if self._out_points_buffer is None:
self._out_points_buffer = points.clone()
@@ -181,9 +190,9 @@ class RobotSegmenter:
out_points = points_in_robot_frame
dist = self._robot_world.get_point_robot_distance(out_points, q)
mask, filtered_image = mask_image(camera_obs.depth_image, dist, self.distance_threshold)
mask, filtered_image = mask_spheres_image(
camera_obs.depth_image, robot_spheres, out_points, self.distance_threshold
)
return mask, filtered_image
@@ -200,7 +209,7 @@ class RobotSegmenter:
return self.kinematics.base_link
@torch.jit.script
@get_torch_jit_decorator()
def mask_image(
image: torch.Tensor, distance: torch.Tensor, distance_threshold: float
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -212,3 +221,36 @@ def mask_image(
mask = torch.logical_and((image > 0.0), (distance > -distance_threshold))
filtered_image = torch.where(mask, 0, image)
return mask, filtered_image
@get_torch_jit_decorator()
def mask_spheres_image(
image: torch.Tensor,
link_spheres_tensor: torch.Tensor,
points: torch.Tensor,
distance_threshold: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
if link_spheres_tensor.shape[0] != 1:
assert link_spheres_tensor.shape[0] == points.shape[0]
if len(points.shape) == 2:
points = points.unsqueeze(0)
robot_spheres = link_spheres_tensor.view(link_spheres_tensor.shape[0], -1, 4).contiguous()
robot_spheres = robot_spheres.unsqueeze(-3)
robot_radius = robot_spheres[..., 3]
points = points.unsqueeze(-2)
sph_distance = -1 * (
torch.linalg.norm(points - robot_spheres[..., :3], dim=-1) - robot_radius
) # b, n_spheres
distance = torch.max(sph_distance, dim=-1)[0]
distance = distance.view(
image.shape[0],
image.shape[1],
image.shape[2],
)
mask = torch.logical_and((image > 0.0), (distance > -distance_threshold))
filtered_image = torch.where(mask, 0, image)
return mask, filtered_image

View File

@@ -36,6 +36,11 @@ from curobo.types.robot import RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_error
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.torch_utils import (
get_torch_compile_options,
get_torch_jit_decorator,
is_torch_compile_available,
)
from curobo.util.warp import init_warp
from curobo.util_file import get_robot_configs_path, get_world_configs_path, join_path, load_yaml
@@ -192,6 +197,9 @@ class RobotWorld(RobotWorldConfig):
def update_world(self, world_config: WorldConfig):
self.world_model.load_collision_model(world_config)
def clear_world_cache(self):
self.world_model.clear_cache()
def get_collision_distance(
self, x_sph: torch.Tensor, env_query_idx: Optional[torch.Tensor] = None
) -> torch.Tensor:
@@ -364,16 +372,7 @@ class RobotWorld(RobotWorldConfig):
if len(q.shape) == 1:
log_error("q should be of shape [b, dof]")
kin_state = self.get_kinematics(q)
b, n = None, None
if len(points.shape) == 3:
b, n, _ = points.shape
points = points.view(b * n, 3)
pt_distance = point_robot_distance(kin_state.link_spheres_tensor, points)
if b is not None:
pt_distance = pt_distance.view(b, n)
return pt_distance
def get_active_js(self, full_js: JointState):
@@ -382,27 +381,50 @@ class RobotWorld(RobotWorldConfig):
return out_js
@torch.jit.script
@get_torch_jit_decorator()
def sum_mask(d1, d2, d3):
d_total = d1 + d2 + d3
d_mask = d_total == 0.0
return d_mask.view(-1)
@torch.jit.script
@get_torch_jit_decorator()
def mask(d1, d2, d3):
d_total = d1 + d2 + d3
d_mask = d_total == 0.0
return d_mask
@torch.jit.script
@get_torch_jit_decorator()
def point_robot_distance(link_spheres_tensor, points):
robot_spheres = link_spheres_tensor.view(1, -1, 4).contiguous()
robot_radius = robot_spheres[:, :, 3]
points = points.unsqueeze(1)
sph_distance = (
torch.linalg.norm(points - robot_spheres[:, :, :3], dim=-1) - robot_radius
"""Compute distance between robot and points
Args:
link_spheres_tensor: [batch_robot, n_robot_spheres, 4]
points: [batch_points, n_points, 3]
Returns:
distance: [batch_points, n_points]
"""
if link_spheres_tensor.shape[0] != 1:
assert link_spheres_tensor.shape[0] == points.shape[0]
squeeze_shape = False
n = 1
if len(points.shape) == 2:
squeeze_shape = True
n, _ = points.shape
points = points.unsqueeze(0)
robot_spheres = link_spheres_tensor.view(link_spheres_tensor.shape[0], -1, 4).contiguous()
robot_spheres = robot_spheres.unsqueeze(-3)
robot_radius = robot_spheres[..., 3]
points = points.unsqueeze(-2)
sph_distance = -1 * (
torch.linalg.norm(points - robot_spheres[..., :3], dim=-1) - robot_radius
) # b, n_spheres
pt_distance = torch.max(-1 * sph_distance, dim=-1)[0]
pt_distance = torch.max(sph_distance, dim=-1)[0]
if squeeze_shape:
pt_distance = pt_distance.view(n)
return pt_distance

View File

@@ -20,6 +20,7 @@ import torch.autograd.profiler as profiler
# CuRobo
from curobo.types.robot import JointState
from curobo.types.tensor import T_DOF
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.trajectory import calculate_dt
@@ -32,7 +33,7 @@ class TrajEvaluatorConfig:
max_dt: float = 0.1
@torch.jit.script
@get_torch_jit_decorator()
def compute_path_length(vel, traj_dt, cspace_distance_weight):
pl = torch.sum(
torch.sum(torch.abs(vel) * traj_dt.unsqueeze(-1) * cspace_distance_weight, dim=-1), dim=-1
@@ -40,24 +41,25 @@ def compute_path_length(vel, traj_dt, cspace_distance_weight):
return pl
@torch.jit.script
@get_torch_jit_decorator()
def compute_path_length_cost(vel, cspace_distance_weight):
pl = torch.sum(torch.sum(torch.abs(vel) * cspace_distance_weight, dim=-1), dim=-1)
return pl
@torch.jit.script
@get_torch_jit_decorator()
def smooth_cost(abs_acc, abs_jerk, opt_dt):
# acc = torch.max(torch.max(abs_acc, dim=-1)[0], dim=-1)[0]
# jerk = torch.max(torch.max(abs_jerk, dim=-1)[0], dim=-1)[0]
jerk = torch.mean(torch.max(abs_jerk, dim=-1)[0], dim=-1)
mean_acc = torch.mean(torch.max(abs_acc, dim=-1)[0], dim=-1) # [0]
a = (jerk * 0.001) + 5.0 * opt_dt + (mean_acc * 0.01)
a = (jerk * 0.001) + 10.0 * opt_dt + (mean_acc * 0.01)
# a = (jerk * 0.001) + 50.0 * opt_dt + (mean_acc * 0.01)
return a
@torch.jit.script
@get_torch_jit_decorator()
def compute_smoothness(
vel: torch.Tensor,
acc: torch.Tensor,
@@ -104,7 +106,7 @@ def compute_smoothness(
return (acc_label, smooth_cost(abs_acc, abs_jerk, dt_score))
@torch.jit.script
@get_torch_jit_decorator()
def compute_smoothness_opt_dt(
vel, acc, jerk, max_vel: torch.Tensor, max_acc: float, max_jerk: float, opt_dt: torch.Tensor
):

View File

@@ -34,6 +34,7 @@ from curobo.types.robot import JointState, RobotConfig
from curobo.types.tensor import T_BDOF, T_BValue_bool, T_BValue_float
from curobo.util.logger import log_error, log_warn
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util_file import (
get_robot_configs_path,
get_task_configs_path,
@@ -1010,7 +1011,7 @@ class IKSolver(IKSolverConfig):
]
@torch.jit.script
@get_torch_jit_decorator()
def get_success(
feasible,
position_error,
@@ -1028,7 +1029,7 @@ def get_success(
return success
@torch.jit.script
@get_torch_jit_decorator()
def get_result(
pose_error,
position_error,

View File

@@ -197,7 +197,7 @@ class MotionGenConfig:
smooth_weight: List[float] = None,
finetune_smooth_weight: Optional[List[float]] = None,
state_finite_difference_mode: Optional[str] = None,
finetune_dt_scale: float = 0.98,
finetune_dt_scale: float = 0.95,
maximum_trajectory_time: Optional[float] = None,
maximum_trajectory_dt: float = 0.1,
velocity_scale: Optional[Union[List[float], float]] = None,
@@ -2086,6 +2086,7 @@ class MotionGen(MotionGenConfig):
# self.trajopt_solver.compute_metrics(not og_evaluate, og_evaluate)
if self.store_debug_in_result:
result.debug_info["trajopt_result"] = traj_result
# run finetune
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
with profiler.record_function("motion_gen/finetune_trajopt"):
@@ -2113,6 +2114,9 @@ class MotionGen(MotionGenConfig):
traj_result.solve_time = og_solve_time
if self.store_debug_in_result:
result.debug_info["finetune_trajopt_result"] = traj_result
elif plan_config.enable_finetune_trajopt:
traj_result.success = traj_result.success[:, 0]
result.success = traj_result.success
result.interpolated_plan = traj_result.interpolated_solution
@@ -2190,7 +2194,7 @@ class MotionGen(MotionGenConfig):
# warm up js_trajopt:
goal_state = start_state.clone()
goal_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(2):
for _ in range(3):
self.plan_single_js(start_state, goal_state, MotionGenPlanConfig(max_attempts=1))
if enable_graph:
start_state = JointState.from_position(
@@ -2214,7 +2218,7 @@ class MotionGen(MotionGenConfig):
if n_goalset == -1:
retract_pose = Pose(state.ee_pos_seq, quaternion=state.ee_quat_seq)
start_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(2):
for _ in range(3):
self.plan_single(
start_state,
retract_pose,
@@ -2243,7 +2247,7 @@ class MotionGen(MotionGenConfig):
quaternion=state.ee_quat_seq.repeat(n_goalset, 1).view(1, n_goalset, 4),
)
start_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(2):
for _ in range(3):
self.plan_goalset(
start_state,
retract_pose,
@@ -2278,7 +2282,7 @@ class MotionGen(MotionGenConfig):
retract_pose = Pose(state.ee_pos_seq, quaternion=state.ee_quat_seq)
start_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(2):
for _ in range(3):
if batch_env_mode:
self.plan_batch_env(
start_state,
@@ -2307,7 +2311,7 @@ class MotionGen(MotionGenConfig):
.contiguous(),
)
start_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(2):
for _ in range(3):
if batch_env_mode:
self.plan_batch_env_goalset(
start_state,

View File

@@ -41,6 +41,7 @@ from curobo.types.robot import JointState, RobotConfig
from curobo.types.tensor import T_BDOF, T_DOF, T_BValue_bool, T_BValue_float
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.torch_utils import get_torch_jit_decorator, is_torch_compile_available
from curobo.util.trajectory import (
InterpolateType,
calculate_dt_no_clamp,
@@ -877,24 +878,37 @@ class TrajOptSolver(TrajOptSolverConfig):
result.metrics.goalset_index = metrics.goalset_index
st_time = time.time()
feasible = torch.all(result.metrics.feasible, dim=-1)
if result.metrics.cspace_error is None and result.metrics.position_error is None:
raise log_error("convergence check requires either goal_pose or goal_state")
if result.metrics.position_error is not None:
converge = torch.logical_and(
result.metrics.position_error[..., -1] <= self.position_threshold,
result.metrics.rotation_error[..., -1] <= self.rotation_threshold,
)
elif result.metrics.cspace_error is not None:
converge = result.metrics.cspace_error[..., -1] <= self.cspace_threshold
else:
raise ValueError("convergence check requires either goal_pose or goal_state")
success = jit_feasible_success(
result.metrics.feasible,
result.metrics.position_error,
result.metrics.rotation_error,
result.metrics.cspace_error,
self.position_threshold,
self.rotation_threshold,
self.cspace_threshold,
)
if False:
feasible = torch.all(result.metrics.feasible, dim=-1)
success = torch.logical_and(feasible, converge)
if result.metrics.position_error is not None:
converge = torch.logical_and(
result.metrics.position_error[..., -1] <= self.position_threshold,
result.metrics.rotation_error[..., -1] <= self.rotation_threshold,
)
elif result.metrics.cspace_error is not None:
converge = result.metrics.cspace_error[..., -1] <= self.cspace_threshold
else:
raise ValueError("convergence check requires either goal_pose or goal_state")
success = torch.logical_and(feasible, converge)
if return_all_solutions:
traj_result = TrajResult(
success=success,
goal=goal,
solution=result.action.scale(self.solver_dt / opt_dt.view(-1, 1, 1)),
solution=result.action.scale_by_dt(self.solver_dt_tensor, opt_dt.view(-1, 1, 1)),
seed=seed_traj,
position_error=result.metrics.position_error,
rotation_error=result.metrics.rotation_error,
@@ -928,49 +942,86 @@ class TrajOptSolver(TrajOptSolverConfig):
)
with profiler.record_function("trajopt/best_select"):
success[~smooth_label] = False
# get the best solution:
if result.metrics.pose_error is not None:
convergence_error = result.metrics.pose_error[..., -1]
elif result.metrics.cspace_error is not None:
convergence_error = result.metrics.cspace_error[..., -1]
if True: # not get_torch_jit_decorator() == torch.jit.script:
# This only works if torch compile is available:
(
idx,
position_error,
rotation_error,
cspace_error,
goalset_index,
opt_dt,
success,
) = jit_trajopt_best_select(
success,
smooth_label,
result.metrics.cspace_error,
result.metrics.pose_error,
result.metrics.position_error,
result.metrics.rotation_error,
result.metrics.goalset_index,
result.metrics.cost,
smooth_cost,
batch_mode,
goal.batch,
num_seeds,
self._col,
opt_dt,
)
if batch_mode:
last_tstep = [last_tstep[i] for i in idx]
else:
last_tstep = [last_tstep[idx.item()]]
best_act_seq = result.action[idx]
best_raw_action = result.raw_action[idx]
interpolated_traj = interpolated_trajs[idx]
else:
raise ValueError("convergence check requires either goal_pose or goal_state")
running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
error = convergence_error + smooth_cost + running_cost
error[~success] += 10000.0
if batch_mode:
idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
idx = idx + num_seeds * self._col
last_tstep = [last_tstep[i] for i in idx]
success = success[idx]
else:
idx = torch.argmin(error, dim=0)
success[~smooth_label] = False
# get the best solution:
if result.metrics.pose_error is not None:
convergence_error = result.metrics.pose_error[..., -1]
elif result.metrics.cspace_error is not None:
convergence_error = result.metrics.cspace_error[..., -1]
else:
raise ValueError(
"convergence check requires either goal_pose or goal_state"
)
running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
error = convergence_error + smooth_cost + running_cost
error[~success] += 10000.0
if batch_mode:
idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
idx = idx + num_seeds * self._col
last_tstep = [last_tstep[i] for i in idx]
success = success[idx]
else:
idx = torch.argmin(error, dim=0)
last_tstep = [last_tstep[idx.item()]]
success = success[idx : idx + 1]
last_tstep = [last_tstep[idx.item()]]
success = success[idx : idx + 1]
best_act_seq = result.action[idx]
best_raw_action = result.raw_action[idx]
interpolated_traj = interpolated_trajs[idx]
goalset_index = position_error = rotation_error = cspace_error = None
if result.metrics.position_error is not None:
position_error = result.metrics.position_error[idx, -1]
if result.metrics.rotation_error is not None:
rotation_error = result.metrics.rotation_error[idx, -1]
if result.metrics.cspace_error is not None:
cspace_error = result.metrics.cspace_error[idx, -1]
if result.metrics.goalset_index is not None:
goalset_index = result.metrics.goalset_index[idx, -1]
best_act_seq = result.action[idx]
best_raw_action = result.raw_action[idx]
interpolated_traj = interpolated_trajs[idx]
goalset_index = position_error = rotation_error = cspace_error = None
if result.metrics.position_error is not None:
position_error = result.metrics.position_error[idx, -1]
if result.metrics.rotation_error is not None:
rotation_error = result.metrics.rotation_error[idx, -1]
if result.metrics.cspace_error is not None:
cspace_error = result.metrics.cspace_error[idx, -1]
if result.metrics.goalset_index is not None:
goalset_index = result.metrics.goalset_index[idx, -1]
opt_dt = opt_dt[idx]
opt_dt = opt_dt[idx]
if self.sync_cuda_time:
torch.cuda.synchronize()
if len(best_act_seq.shape) == 3:
opt_dt_v = opt_dt.view(-1, 1, 1)
else:
opt_dt_v = opt_dt.view(1, 1)
opt_solution = best_act_seq.scale(self.solver_dt / opt_dt_v)
opt_solution = best_act_seq.scale_by_dt(self.solver_dt_tensor, opt_dt_v)
select_time = time.time() - st_time
debug_info = None
if self.store_debug_in_result:
@@ -1174,7 +1225,7 @@ class TrajOptSolver(TrajOptSolverConfig):
self._max_joint_vel,
self._max_joint_acc,
self._max_joint_jerk,
self.solver_dt,
self.solver_dt_tensor,
kind=self.interpolation_type,
tensor_args=self.tensor_args,
out_traj_state=self._interpolated_traj_buffer,
@@ -1224,7 +1275,12 @@ class TrajOptSolver(TrajOptSolverConfig):
@property
def solver_dt(self):
return self.solver.safety_rollout.dynamics_model.dt_traj_params.base_dt
return self.solver.safety_rollout.dynamics_model.traj_dt[0]
# return self.solver.safety_rollout.dynamics_model.dt_traj_params.base_dt
@property
def solver_dt_tensor(self):
return self.solver.safety_rollout.dynamics_model.traj_dt[0]
def update_solver_dt(
self,
@@ -1254,3 +1310,79 @@ class TrajOptSolver(TrajOptSolverConfig):
for rollout in rollouts
if isinstance(rollout, ArmReacher)
]
@get_torch_jit_decorator()
def jit_feasible_success(
feasible,
position_error: Union[torch.Tensor, None],
rotation_error: Union[torch.Tensor, None],
cspace_error: Union[torch.Tensor, None],
position_threshold: float,
rotation_threshold: float,
cspace_threshold: float,
):
feasible = torch.all(feasible, dim=-1)
converge = feasible
if position_error is not None and rotation_error is not None:
converge = torch.logical_and(
position_error[..., -1] <= position_threshold,
rotation_error[..., -1] <= rotation_threshold,
)
elif cspace_error is not None:
converge = cspace_error[..., -1] <= cspace_threshold
success = torch.logical_and(feasible, converge)
return success
@get_torch_jit_decorator(only_valid_for_compile=True)
def jit_trajopt_best_select(
success,
smooth_label,
cspace_error: Union[torch.Tensor, None],
pose_error: Union[torch.Tensor, None],
position_error: Union[torch.Tensor, None],
rotation_error: Union[torch.Tensor, None],
goalset_index: Union[torch.Tensor, None],
cost,
smooth_cost,
batch_mode: bool,
batch: int,
num_seeds: int,
col,
opt_dt,
):
success[~smooth_label] = False
convergence_error = 0
# get the best solution:
if pose_error is not None:
convergence_error = pose_error[..., -1]
elif cspace_error is not None:
convergence_error = cspace_error[..., -1]
running_cost = torch.mean(cost, dim=-1) * 0.0001
error = convergence_error + smooth_cost + running_cost
error[~success] += 10000.0
if batch_mode:
idx = torch.argmin(error.view(batch, num_seeds), dim=-1)
idx = idx + num_seeds * col
success = success[idx]
else:
idx = torch.argmin(error, dim=0)
success = success[idx : idx + 1]
# goalset_index = position_error = rotation_error = cspace_error = None
if position_error is not None:
position_error = position_error[idx, -1]
if rotation_error is not None:
rotation_error = rotation_error[idx, -1]
if cspace_error is not None:
cspace_error = cspace_error[idx, -1]
if goalset_index is not None:
goalset_index = goalset_index[idx, -1]
opt_dt = opt_dt[idx]
return idx, position_error, rotation_error, cspace_error, goalset_index, opt_dt, success