Significantly improved convergence for mesh and cuboid, new ESDF collision.
@@ -56,8 +56,10 @@ def _get_version():
# `importlib_metadata` is the back ported library for older versions of python.
# Third Party
from importlib_metadata import version

return version("nvidia_curobo")
try:
return version("nvidia_curobo")
except:
return "v0.7.0-no-tag"


# Set `__version__` attribute
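Note: the hunk above replaces the bare version lookup with a guarded one; a minimal self-contained sketch of that fallback pattern (the package name and fallback string simply mirror the hunk):

```python
def _get_version() -> str:
    # Look up the installed distribution's version; fall back to a pinned
    # string when the package metadata is not available.
    try:
        # `importlib_metadata` is the backported library for older Python versions.
        from importlib_metadata import version

        return version("nvidia_curobo")
    except Exception:
        return "v0.7.0-no-tag"
```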
@@ -280,6 +280,7 @@
<origin rpy="0 0 0" xyz="0 -0.04 0.0584"/>
<axis xyz="0 -1 0"/>
<limit effort="20" lower="0.0" upper="0.04" velocity="0.2"/>
<mimic joint="panda_finger_joint1"/>
</joint>
<link name="right_gripper">
<inertial>

@@ -13,7 +13,7 @@ robot_cfg:
kinematics:
use_usd_kinematics: False
isaac_usd_path: "/Isaac/Robots/Franka/franka.usd"
usd_path: "robot/franka/franka_panda.usda"
usd_path: "robot/non_shipping/franka/franka_panda_meters.usda"
usd_robot_root: "/panda"
usd_flip_joints: ["panda_joint1","panda_joint2","panda_joint3","panda_joint4", "panda_joint5",
"panda_joint6","panda_joint7","panda_finger_joint1", "panda_finger_joint2"]

@@ -39,7 +39,13 @@ robot_cfg:
}
urdf_path: robot/iiwa_allegro_description/iiwa.urdf
asset_root_path: robot/iiwa_allegro_description

mesh_link_names:
- iiwa7_link_1
- iiwa7_link_2
- iiwa7_link_3
- iiwa7_link_4
- iiwa7_link_5
- iiwa7_link_6
cspace:
joint_names:
[

@@ -30,6 +30,25 @@ robot_cfg:
- ring_link_3
- thumb_link_2
- thumb_link_3
mesh_link_names:
- iiwa7_link_1
- iiwa7_link_2
- iiwa7_link_3
- iiwa7_link_4
- iiwa7_link_5
- iiwa7_link_6
- palm_link
- index_link_1
- index_link_2
- index_link_3
- middle_link_1
- middle_link_2
- middle_link_3
- ring_link_1
- ring_link_2
- ring_link_3
- thumb_link_2
- thumb_link_3
collision_sphere_buffer: 0.005
collision_spheres: spheres/iiwa_allegro.yml
ee_link: palm_link

@@ -134,11 +134,11 @@ collision_spheres:
"radius": 0.022
panda_leftfinger:
- "center": [0.0, 0.01, 0.043]
"radius": 0.011 # 25
"radius": 0.011 #0.025 # 0.011
- "center": [0.0, 0.02, 0.015]
"radius": 0.011 # 25
"radius": 0.011 #0.025 # 0.011
panda_rightfinger:
- "center": [0.0, -0.01, 0.043]
"radius": 0.011 #25
"radius": 0.011 #0.025 #0.011
- "center": [0.0, -0.02, 0.015]
"radius": 0.011 #25
"radius": 0.011 #0.025 #0.011

@@ -58,7 +58,7 @@ collision_spheres:
radius: 0.07
tool0:
- center: [0, 0, 0.12]
radius: 0.05
radius: -0.01
camera_mount:
- center: [0, 0.11, -0.01]
radius: 0.06

@@ -38,7 +38,7 @@ robot_cfg:
'wrist_1_link': 0,
'wrist_2_link': 0,
'wrist_3_link' : 0,
'tool0': 0,
'tool0': 0.05,
}
mesh_link_names: [ 'shoulder_link','upper_arm_link', 'forearm_link', 'wrist_1_link', 'wrist_2_link' ,'wrist_3_link' ]
lock_joints: null

@@ -92,7 +92,7 @@ robot_cfg:
"radius": 0.043
tool0:
- "center": [0.001, 0.001, 0.05]
"radius": 0.05
"radius": -0.01 #0.05

collision_sphere_buffer: 0.005
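Note: the `radius: -0.01` entries above effectively disable a collision sphere; the kernels later in this commit skip spheres whose radius is negative (`if in_rad < 0.0: ...`). A small illustrative sketch of that convention (plain Python, not curobo's loader):

```python
# Spheres copied from the tool0 hunk above; a negative radius marks a sphere
# that the collision kernels will ignore.
spheres = {
    "tool0": [{"center": [0.001, 0.001, 0.05], "radius": -0.01}],
    "camera_mount": [{"center": [0.0, 0.11, -0.01], "radius": 0.06}],
}
active = {link: [s for s in ss if s["radius"] >= 0.0] for link, ss in spheres.items()}
print(active)  # tool0 ends up with no active collision spheres
```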
@@ -109,6 +109,7 @@ robot_cfg:
'wrist_1_link': 0,
'wrist_2_link': 0,
'wrist_3_link' : 0,
'tool0': 0.025,
}

use_global_cumul: True

@@ -55,17 +55,17 @@ cost:

bound_cfg:
weight: [5000.0, 50000.0, 50000.0, 50000.0] # needs to be 3 values
smooth_weight: [0.0,5000.0, 50.0, 0.0] # [vel, acc, jerk,]
smooth_weight: [0.0,5000.0, 50.0, 0.0] #[0.0,5000.0, 50.0, 0.0] # [vel, acc, jerk,]
run_weight_velocity: 0.0
run_weight_acceleration: 1.0
run_weight_jerk: 1.0
activation_distance: [0.1,0.1,0.1,10.0] # for position, velocity, acceleration and jerk
activation_distance: [0.1,0.1,0.1,0.1] # for position, velocity, acceleration and jerk
null_space_weight: [0.0]

primitive_collision_cfg:
weight: 1000000.0
weight: 1000000.0 #1000000.0 1000000
use_sweep: True
sweep_steps: 6
sweep_steps: 4
classify: False
use_sweep_kernel: True
use_speed_metric: True
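Note: `activation_distance` above sets the margin at which each penalty turns on; the collision kernels later in this commit use the piecewise-smooth form sketched below (linear beyond the margin, quadratic inside it). The function name is mine, the formula is taken from the kernel code:

```python
def activation_penalty(dist: float, eta: float) -> float:
    # dist: penetration plus margin (positive means inside the activation band),
    # eta: activation distance. Mirrors the kernel's dist_metric computation.
    if dist <= 0.0:
        return 0.0
    if dist > eta:
        return dist - 0.5 * eta
    return (0.5 / eta) * dist * dist


print([round(activation_penalty(d, 0.1), 4) for d in (-0.05, 0.05, 0.1, 0.3)])
# [0.0, 0.0125, 0.05, 0.25] -- continuous and once-differentiable at dist == eta
```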
@@ -79,7 +79,7 @@ cost:

lbfgs:
n_iters: 400 # 400
n_iters: 300 # 400
inner_iters: 25
cold_start_n_iters: null
min_iters: 25

@@ -89,7 +89,7 @@ lbfgs:
cost_delta_threshold: 1.0
cost_relative_threshold: 0.999 #0.999
epsilon: 0.01
history: 15 #15
history: 27 #15
use_cuda_graph: True
n_problems: 1
store_debug: False

@@ -47,7 +47,7 @@ cost:
activation_distance: [0.1]
null_space_weight: [1.0]
primitive_collision_cfg:
weight: 50000.0
weight: 5000.0
use_sweep: False
classify: False
activation_distance: 0.01

@@ -57,11 +57,11 @@ cost:

lbfgs:
n_iters: 100 #60
inner_iters: 25
n_iters: 80 #60
inner_iters: 20
cold_start_n_iters: null
min_iters: 20
line_search_scale: [0.01, 0.3, 0.7, 1.0]
line_search_scale: [0.1, 0.3, 0.7, 1.0]
fixed_iters: True
cost_convergence: 1e-7
cost_delta_threshold: 1e-6 #0.0001

@@ -40,7 +40,7 @@ cost:

link_pose_cfg:
vec_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] # orientation, position for all timesteps
run_vec_weight: [1.00,1.00,1.00,1.0,1.0,1.0] # running weight orientation, position
run_vec_weight: [0.00,0.00,0.00,0.0,0.0,0.0] # running weight orientation, position
weight: [2000,50000.0,30,50] #[150.0, 2000.0, 30, 40]
vec_convergence: [0.0,0.0] # orientation, position, orientation metric activation, position metric activation
terminal: True

@@ -54,19 +54,17 @@ cost:

bound_cfg:
weight: [5000.0, 50000.0, 50000.0,50000.0] # needs to be 3 values
#weight: [000.0, 000.0, 000.0,000.0]
smooth_weight: [0.0,10000.0,10.0, 0.0] # [vel, acc, jerk, alpha_v-not-used]
#smooth_weight: [0.0,0000.0,0.0, 0.0] # [vel, acc, jerk, alpha_v-not-used]
run_weight_velocity: 0.00
run_weight_acceleration: 1.0
run_weight_jerk: 1.0
activation_distance: [0.1,0.1,0.1,10.0] # for position, velocity, acceleration and jerk
activation_distance: [0.1,0.1,0.1,0.1] # for position, velocity, acceleration and jerk
null_space_weight: [0.0]

primitive_collision_cfg:
weight: 100000.0
use_sweep: True
sweep_steps: 6
sweep_steps: 4
classify: False
use_sweep_kernel: True
use_speed_metric: True

@@ -81,11 +79,11 @@ cost:

lbfgs:
n_iters: 125 #175
n_iters: 100 #175
inner_iters: 25
cold_start_n_iters: null
min_iters: 25
line_search_scale: [0.01,0.3,0.7,1.0] #[0.01,0.2, 0.3,0.5,0.7,0.9, 1.0] #
line_search_scale: [0.01,0.3,0.7,1.0]
fixed_iters: True
cost_convergence: 0.01
cost_delta_threshold: 2000.0

@@ -33,7 +33,7 @@ graph:
sample_pts: 1500
node_similarity_distance: 0.1

rejection_ratio: 20
rejection_ratio: 10
k_nn: 15
max_buffer: 10000
max_cg_buffer: 1000

@@ -34,7 +34,7 @@ cost:
run_weight: 1.0
link_pose_cfg:
vec_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
weight: [30, 50, 10, 10] #[20.0, 100.0]
weight: [0, 50, 10, 10] #[20.0, 100.0]
vec_convergence: [0.00, 0.000] # orientation, position
terminal: False
use_metric: True

@@ -67,7 +67,7 @@ mppi:
cov_type : "DIAG_A" #
kappa : 0.01
null_act_frac : 0.0
sample_mode : 'BEST'
sample_mode : 'MEAN'
base_action : 'REPEAT'
squash_fn : 'CLAMP'
n_problems : 1

@@ -89,7 +89,7 @@ mppi:
sample_mode : 'BEST'
base_action : 'REPEAT'
squash_fn : 'CLAMP'
n_problems : 1
n_problems : 1
use_cuda_graph : True
seed : 0
store_debug : False

@@ -39,7 +39,7 @@ cost:

link_pose_cfg:
vec_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
run_vec_weight: [0.00,0.00,0.00,0.0,0.0,0.0] # running weight orientation, position
run_vec_weight: [0.0,0.0,0.0,0.0,0.0,0.0] # running weight
weight: [0.0, 5000.0, 40, 40]
vec_convergence: [0.0,0.0,1000.0,1000.0]
terminal: True

@@ -63,11 +63,11 @@ cost:

primitive_collision_cfg:
weight: 5000.0
use_sweep: True
use_sweep: False
classify: False
sweep_steps: 4
use_sweep_kernel: True
use_speed_metric: True
use_sweep_kernel: False
use_speed_metric: False
speed_dt: 0.01 # used only for speed metric
activation_distance: 0.025

@@ -92,7 +92,7 @@ mppi:
cov_type : "DIAG_A" #
kappa : 0.001
null_act_frac : 0.0
sample_mode : 'BEST'
sample_mode : 'MEAN'
base_action : 'REPEAT'
squash_fn : 'CLAMP'
n_problems : 1

@@ -180,7 +180,7 @@ class CudaRobotModel(CudaRobotModelConfig):
self._batch_robot_spheres = torch.zeros(
(self._batch_size, self.kinematics_config.total_spheres, 4),
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
dtype=self.tensor_args.collision_geometry_dtype,
)
self._grad_out_q = torch.zeros(
(self._batch_size, self.get_dof()),

@@ -23,6 +23,7 @@ from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
from curobo.types.state import JointState
from curobo.types.tensor import T_DOF
from curobo.util.logger import log_error
from curobo.util.tensor_util import clone_if_not_none, copy_if_not_none

@@ -138,6 +139,13 @@ class CSpaceConfig:
self.acceleration_scale = self.tensor_args.to_device(self.acceleration_scale)
if isinstance(self.jerk_scale, List):
self.jerk_scale = self.tensor_args.to_device(self.jerk_scale)
# check shapes:
if self.retract_config is not None:
dof = self.retract_config.shape
if self.cspace_distance_weight is not None and self.cspace_distance_weight.shape != dof:
log_error("cspace_distance_weight shape does not match retract_config")
if self.null_space_weight is not None and self.null_space_weight.shape != dof:
log_error("null_space_weight shape does not match retract_config")

def inplace_reindex(self, joint_names: List[str]):
new_index = [self.joint_names.index(j) for j in joint_names]
@@ -207,8 +215,8 @@ class CSpaceConfig:
):
retract_config = ((joint_position_upper + joint_position_lower) / 2).flatten()
n_dof = retract_config.shape[-1]
null_space_weight = torch.ones(n_dof, **vars(tensor_args))
cspace_distance_weight = torch.ones(n_dof, **vars(tensor_args))
null_space_weight = torch.ones(n_dof, **(tensor_args.as_torch_dict()))
cspace_distance_weight = torch.ones(n_dof, **(tensor_args.as_torch_dict()))
return CSpaceConfig(
joint_names,
retract_config,

@@ -289,8 +297,8 @@ class KinematicsTensorConfig:
retract_config = (
(self.joint_limits.position[1] + self.joint_limits.position[0]) / 2
).flatten()
null_space_weight = torch.ones(self.n_dof, **vars(self.tensor_args))
cspace_distance_weight = torch.ones(self.n_dof, **vars(self.tensor_args))
null_space_weight = torch.ones(self.n_dof, **(self.tensor_args.as_torch_dict()))
cspace_distance_weight = torch.ones(self.n_dof, **(self.tensor_args.as_torch_dict()))
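Note: the switch from `**vars(tensor_args)` to `**(tensor_args.as_torch_dict())`, together with the new `collision_geometry_dtype` field above, suggests the device/dtype container now carries fields that torch factory functions do not accept. A hypothetical sketch of the pattern; the dataclass here is illustrative, not curobo's actual `TensorDeviceType`:

```python
from dataclasses import dataclass

import torch


@dataclass
class DeviceArgs:
    device: torch.device = torch.device("cpu")
    dtype: torch.dtype = torch.float32
    # Extra field that torch.ones(**kwargs) would reject if passed via vars():
    collision_geometry_dtype: torch.dtype = torch.float32

    def as_torch_dict(self) -> dict:
        # Only the kwargs that torch factory functions understand.
        return {"device": self.device, "dtype": self.dtype}


args = DeviceArgs()
null_space_weight = torch.ones(7, **args.as_torch_dict())  # ok
# torch.ones(7, **vars(args)) would raise TypeError on the extra field.
```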
joint_names = self.joint_names
self.cspace = CSpaceConfig(
joint_names,

@@ -175,7 +175,11 @@ class UrdfKinematicsParser(KinematicsParser):
return txt

def get_link_mesh(self, link_name):
m = self._robot.link_map[link_name].visuals[0].geometry.mesh
link_data = self._robot.link_map[link_name]

if len(link_data.visuals) == 0:
log_error(link_name + " not found in urdf, remove from mesh_link_names")
m = link_data.visuals[0].geometry.mesh
mesh_pose = self._robot.link_map[link_name].visuals[0].origin
# read visual material:
if mesh_pose is None:
@@ -57,7 +57,8 @@ std::vector<torch::Tensor>swept_sphere_obb_clpt(
|
||||
const bool enable_speed_metric,
|
||||
const bool transform_back,
|
||||
const bool compute_distance,
|
||||
const bool use_batch_env);
|
||||
const bool use_batch_env,
|
||||
const bool sum_collisions);
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 4
|
||||
@@ -66,6 +67,7 @@ sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 4
|
||||
torch::Tensor sparsity_idx,
|
||||
const torch::Tensor weight,
|
||||
const torch::Tensor activation_distance,
|
||||
const torch::Tensor max_distance,
|
||||
const torch::Tensor obb_accel, // n_boxes, 4, 4
|
||||
const torch::Tensor obb_bounds, // n_boxes, 3
|
||||
const torch::Tensor obb_pose, // n_boxes, 4, 4
|
||||
@@ -78,8 +80,52 @@ sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 4
|
||||
const int n_spheres,
|
||||
const bool transform_back,
|
||||
const bool compute_distance,
|
||||
const bool use_batch_env);
|
||||
const bool use_batch_env,
|
||||
const bool sum_collisions,
|
||||
const bool compute_esdf);
|
||||
std::vector<torch::Tensor>
|
||||
sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
|
||||
torch::Tensor distance,
|
||||
torch::Tensor closest_point, // batch size, 3
|
||||
torch::Tensor sparsity_idx, const torch::Tensor weight,
|
||||
const torch::Tensor activation_distance,
|
||||
const torch::Tensor max_distance,
|
||||
const torch::Tensor grid_features, // n_boxes, 4, 4
|
||||
const torch::Tensor grid_params, // n_boxes, 3
|
||||
const torch::Tensor grid_pose, // n_boxes, 4, 4
|
||||
const torch::Tensor grid_enable, // n_boxes, 4, 4
|
||||
const torch::Tensor n_env_grid,
|
||||
const torch::Tensor env_query_idx, // n_boxes, 4, 4
|
||||
const int max_nobs, const int batch_size, const int horizon,
|
||||
const int n_spheres, const bool transform_back,
|
||||
const bool compute_distance, const bool use_batch_env,
|
||||
const bool sum_collisions,
|
||||
const bool compute_esdf);
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
swept_sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
|
||||
torch::Tensor distance,
|
||||
torch::Tensor closest_point, // batch size, 3
|
||||
torch::Tensor sparsity_idx, const torch::Tensor weight,
|
||||
const torch::Tensor activation_distance,
|
||||
const torch::Tensor max_distance,
|
||||
const torch::Tensor speed_dt,
|
||||
const torch::Tensor grid_features, // n_boxes, 4, 4
|
||||
const torch::Tensor grid_params, // n_boxes, 3
|
||||
const torch::Tensor grid_pose, // n_boxes, 4, 4
|
||||
const torch::Tensor grid_enable, // n_boxes, 4, 4
|
||||
const torch::Tensor n_env_grid,
|
||||
const torch::Tensor env_query_idx, // n_boxes, 4, 4
|
||||
const int max_nobs,
|
||||
const int batch_size,
|
||||
const int horizon,
|
||||
const int n_spheres,
|
||||
const int sweep_steps,
|
||||
const bool enable_speed_metric,
|
||||
const bool transform_back,
|
||||
const bool compute_distance,
|
||||
const bool use_batch_env,
|
||||
const bool sum_collisions);
|
||||
std::vector<torch::Tensor>pose_distance(
|
||||
torch::Tensor out_distance,
|
||||
torch::Tensor out_position_distance,
|
||||
@@ -159,11 +205,11 @@ std::vector<torch::Tensor>self_collision_distance_wrapper(
|
||||
|
||||
std::vector<torch::Tensor>sphere_obb_clpt_wrapper(
|
||||
const torch::Tensor sphere_position, // batch_size, 4
|
||||
|
||||
torch::Tensor distance,
|
||||
torch::Tensor closest_point, // batch size, 3
|
||||
torch::Tensor sparsity_idx, const torch::Tensor weight,
|
||||
const torch::Tensor activation_distance,
|
||||
const torch::Tensor max_distance,
|
||||
const torch::Tensor obb_accel, // n_boxes, 4, 4
|
||||
const torch::Tensor obb_bounds, // n_boxes, 3
|
||||
const torch::Tensor obb_pose, // n_boxes, 4, 4
|
||||
@@ -171,8 +217,10 @@ std::vector<torch::Tensor>sphere_obb_clpt_wrapper(
|
||||
const torch::Tensor n_env_obb, // n_boxes, 4, 4
|
||||
const torch::Tensor env_query_idx, // n_boxes, 4, 4
|
||||
const int max_nobs, const int batch_size, const int horizon,
|
||||
const int n_spheres, const bool transform_back, const bool compute_distance,
|
||||
const bool use_batch_env)
|
||||
const int n_spheres,
|
||||
const bool transform_back, const bool compute_distance,
|
||||
const bool use_batch_env, const bool sum_collisions = true,
|
||||
const bool compute_esdf = false)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard guard(sphere_position.device());
|
||||
|
||||
@@ -185,9 +233,9 @@ std::vector<torch::Tensor>sphere_obb_clpt_wrapper(
|
||||
CHECK_INPUT(obb_accel);
|
||||
return sphere_obb_clpt(
|
||||
sphere_position, distance, closest_point, sparsity_idx, weight,
|
||||
activation_distance, obb_accel, obb_bounds, obb_pose, obb_enable,
|
||||
activation_distance, max_distance, obb_accel, obb_bounds, obb_pose, obb_enable,
|
||||
n_env_obb, env_query_idx, max_nobs, batch_size, horizon, n_spheres,
|
||||
transform_back, compute_distance, use_batch_env);
|
||||
transform_back, compute_distance, use_batch_env, sum_collisions, compute_esdf);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>swept_sphere_obb_clpt_wrapper(
|
||||
@@ -205,7 +253,7 @@ std::vector<torch::Tensor>swept_sphere_obb_clpt_wrapper(
|
||||
const int max_nobs, const int batch_size, const int horizon,
|
||||
const int n_spheres, const int sweep_steps, const bool enable_speed_metric,
|
||||
const bool transform_back, const bool compute_distance,
|
||||
const bool use_batch_env)
|
||||
const bool use_batch_env, const bool sum_collisions = true)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard guard(sphere_position.device());
|
||||
|
||||
@@ -218,7 +266,37 @@ std::vector<torch::Tensor>swept_sphere_obb_clpt_wrapper(
|
||||
distance, closest_point, sparsity_idx, weight, activation_distance,
|
||||
speed_dt, obb_accel, obb_bounds, obb_pose, obb_enable, n_env_obb,
|
||||
env_query_idx, max_nobs, batch_size, horizon, n_spheres, sweep_steps,
|
||||
enable_speed_metric, transform_back, compute_distance, use_batch_env);
|
||||
enable_speed_metric, transform_back, compute_distance, use_batch_env, sum_collisions);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
sphere_voxel_clpt_wrapper(const torch::Tensor sphere_position, // batch_size, 3
|
||||
torch::Tensor distance,
|
||||
torch::Tensor closest_point, // batch size, 3
|
||||
torch::Tensor sparsity_idx, const torch::Tensor weight,
|
||||
const torch::Tensor activation_distance,
|
||||
const torch::Tensor max_distance,
|
||||
const torch::Tensor grid_features, // n_boxes, 4, 4
|
||||
const torch::Tensor grid_params, // n_boxes, 3
|
||||
const torch::Tensor grid_pose, // n_boxes, 4, 4
|
||||
const torch::Tensor grid_enable, // n_boxes, 4, 4
|
||||
const torch::Tensor n_env_grid,
|
||||
const torch::Tensor env_query_idx, // n_boxes, 4, 4
|
||||
const int max_ngrid, const int batch_size, const int horizon,
|
||||
const int n_spheres, const bool transform_back,
|
||||
const bool compute_distance, const bool use_batch_env,
|
||||
const bool sum_collisions,
|
||||
const bool compute_esdf)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard guard(sphere_position.device());
|
||||
|
||||
CHECK_INPUT(distance);
|
||||
CHECK_INPUT(closest_point);
|
||||
CHECK_INPUT(sphere_position);
|
||||
return sphere_voxel_clpt(sphere_position, distance, closest_point, sparsity_idx, weight,
|
||||
activation_distance, max_distance, grid_features, grid_params,
|
||||
grid_pose, grid_enable, n_env_grid, env_query_idx, max_ngrid, batch_size, horizon, n_spheres,
|
||||
transform_back, compute_distance, use_batch_env, sum_collisions, compute_esdf);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>pose_distance_wrapper(
|
||||
@@ -297,6 +375,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"Closest Point OBB(curobolib)");
m.def("swept_closest_point", &swept_sphere_obb_clpt_wrapper,
"Swept Closest Point OBB(curobolib)");
m.def("closest_point_voxel", &sphere_voxel_clpt_wrapper,
"Closest Point Voxel(curobolib)");
m.def("swept_closest_point_voxel", &swept_sphere_voxel_clpt,
"Swept Closest Point Voxel(curobolib)");

m.def("self_collision_distance", &self_collision_distance_wrapper,
"Self Collision Distance (curobolib)");
@@ -14,38 +14,6 @@
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
// CUDA forward declarations
|
||||
std::vector<torch::Tensor>reduce_cuda(torch::Tensor vec,
|
||||
torch::Tensor vec2,
|
||||
torch::Tensor rho_buffer,
|
||||
torch::Tensor sum,
|
||||
const int batch_size,
|
||||
const int v_dim,
|
||||
const int m);
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
lbfgs_step_cuda(torch::Tensor step_vec,
|
||||
torch::Tensor rho_buffer,
|
||||
torch::Tensor y_buffer,
|
||||
torch::Tensor s_buffer,
|
||||
torch::Tensor grad_q,
|
||||
const float epsilon,
|
||||
const int batch_size,
|
||||
const int m,
|
||||
const int v_dim);
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
lbfgs_update_cuda(torch::Tensor rho_buffer,
|
||||
torch::Tensor y_buffer,
|
||||
torch::Tensor s_buffer,
|
||||
torch::Tensor q,
|
||||
torch::Tensor grad_q,
|
||||
torch::Tensor x_0,
|
||||
torch::Tensor grad_0,
|
||||
const int batch_size,
|
||||
const int m,
|
||||
const int v_dim);
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
lbfgs_cuda_fuse(torch::Tensor step_vec,
|
||||
torch::Tensor rho_buffer,
|
||||
@@ -59,7 +27,8 @@ lbfgs_cuda_fuse(torch::Tensor step_vec,
|
||||
const int batch_size,
|
||||
const int m,
|
||||
const int v_dim,
|
||||
const bool stable_mode);
|
||||
const bool stable_mode,
|
||||
const bool use_shared_buffers);
|
||||
|
||||
// C++ interface
|
||||
|
||||
@@ -71,58 +40,12 @@ lbfgs_cuda_fuse(torch::Tensor step_vec,
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
lbfgs_step_call(torch::Tensor step_vec, torch::Tensor rho_buffer,
|
||||
torch::Tensor y_buffer, torch::Tensor s_buffer,
|
||||
torch::Tensor grad_q, const float epsilon, const int batch_size,
|
||||
const int m, const int v_dim)
|
||||
{
|
||||
CHECK_INPUT(step_vec);
|
||||
CHECK_INPUT(rho_buffer);
|
||||
CHECK_INPUT(y_buffer);
|
||||
CHECK_INPUT(s_buffer);
|
||||
CHECK_INPUT(grad_q);
|
||||
const at::cuda::OptionalCUDAGuard guard(grad_q.device());
|
||||
|
||||
return lbfgs_step_cuda(step_vec, rho_buffer, y_buffer, s_buffer, grad_q,
|
||||
epsilon, batch_size, m, v_dim);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
lbfgs_update_call(torch::Tensor rho_buffer, torch::Tensor y_buffer,
|
||||
torch::Tensor s_buffer, torch::Tensor q, torch::Tensor grad_q,
|
||||
torch::Tensor x_0, torch::Tensor grad_0, const int batch_size,
|
||||
const int m, const int v_dim)
|
||||
{
|
||||
CHECK_INPUT(rho_buffer);
|
||||
CHECK_INPUT(y_buffer);
|
||||
CHECK_INPUT(s_buffer);
|
||||
CHECK_INPUT(grad_q);
|
||||
CHECK_INPUT(x_0);
|
||||
CHECK_INPUT(grad_0);
|
||||
CHECK_INPUT(q);
|
||||
const at::cuda::OptionalCUDAGuard guard(grad_q.device());
|
||||
|
||||
return lbfgs_update_cuda(rho_buffer, y_buffer, s_buffer, q, grad_q, x_0,
|
||||
grad_0, batch_size, m, v_dim);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
reduce_cuda_call(torch::Tensor vec, torch::Tensor vec2,
|
||||
torch::Tensor rho_buffer, torch::Tensor sum,
|
||||
const int batch_size, const int v_dim, const int m)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard guard(sum.device());
|
||||
|
||||
return reduce_cuda(vec, vec2, rho_buffer, sum, batch_size, v_dim, m);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
lbfgs_call(torch::Tensor step_vec, torch::Tensor rho_buffer,
|
||||
torch::Tensor y_buffer, torch::Tensor s_buffer, torch::Tensor q,
|
||||
torch::Tensor grad_q, torch::Tensor x_0, torch::Tensor grad_0,
|
||||
const float epsilon, const int batch_size, const int m,
|
||||
const int v_dim, const bool stable_mode)
|
||||
const int v_dim, const bool stable_mode, const bool use_shared_buffers)
|
||||
{
|
||||
CHECK_INPUT(step_vec);
|
||||
CHECK_INPUT(rho_buffer);
|
||||
@@ -136,13 +59,11 @@ lbfgs_call(torch::Tensor step_vec, torch::Tensor rho_buffer,
|
||||
|
||||
return lbfgs_cuda_fuse(step_vec, rho_buffer, y_buffer, s_buffer, q, grad_q,
|
||||
x_0, grad_0, epsilon, batch_size, m, v_dim,
|
||||
stable_mode);
|
||||
stable_mode, use_shared_buffers);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("step", &lbfgs_step_call, "L-BFGS step (CUDA)");
|
||||
m.def("update", &lbfgs_update_call, "L-BFGS Update (CUDA)");
|
||||
|
||||
m.def("forward", &lbfgs_call, "L-BFGS Update + Step (CUDA)");
|
||||
m.def("debug_reduce", &reduce_cuda_call, "L-BFGS Debug");
|
||||
}
|
||||
|
||||
(File diff suppressed because it is too large.)
@@ -9,7 +9,6 @@
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cuda.h>
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
@@ -185,11 +185,11 @@ namespace Curobo
|
||||
if (coll_matrix[i * nspheres + j] == 1)
|
||||
{
|
||||
float4 sph1 = __rs_shared[i];
|
||||
|
||||
if ((sph1.w <= 0.0) || (sph2.w <= 0.0))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
//
|
||||
//if ((sph1.w <= 0.0) || (sph2.w <= 0.0))
|
||||
//{
|
||||
// continue;
|
||||
//}
|
||||
float r_diff = sph1.w + sph2.w;
|
||||
float d = sqrt((sph1.x - sph2.x) * (sph1.x - sph2.x) +
|
||||
(sph1.y - sph2.y) * (sph1.y - sph2.y) +
|
||||
@@ -380,10 +380,10 @@ namespace Curobo
|
||||
float4 sph1 = __rs_shared[NBPB * i + l];
|
||||
float4 sph2 = __rs_shared[NBPB * j + l];
|
||||
|
||||
if ((sph1.w <= 0.0) || (sph2.w <= 0.0))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
//if ((sph1.w <= 0.0) || (sph2.w <= 0.0))
|
||||
//{
|
||||
// continue;
|
||||
//}
|
||||
float r_diff =
|
||||
sph1.w + sph2.w; // sum of two radii, radii include respective offsets
|
||||
float d = sqrt((sph1.x - sph2.x) * (sph1.x - sph2.x) +
|
||||
|
||||
(File diff suppressed because it is too large.)
@@ -342,10 +342,8 @@ namespace Curobo
|
||||
float out_pos = 0.0, out_vel = 0.0, out_acc = 0.0, out_jerk = 0.0;
|
||||
float st_pos = 0.0, st_vel = 0.0, st_acc = 0.0;
|
||||
|
||||
const int b_addrs = b_idx * horizon * dof;
|
||||
const int b_addrs_action = b_idx * (horizon - 4) * dof;
|
||||
float in_pos[5]; // create a 5 value scalar
|
||||
const float acc_scale = 1.0;
|
||||
|
||||
#pragma unroll 5
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import torch

# CuRobo
from curobo.util.logger import log_warn
from curobo.util.torch_utils import get_torch_jit_decorator

try:
# CuRobo

@@ -235,7 +236,7 @@ def get_pose_distance_backward(
return r[0], r[1]

@torch.jit.script
@get_torch_jit_decorator()
def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
grad_vec = grad_g_dist + (grad_out_distance * weight)
grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec
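Note: this commit swaps the hard `@torch.jit.script` decorator for `@get_torch_jit_decorator()`. I have not inspected that helper, but such wrappers typically return `torch.jit.script` when scripting is enabled and an identity decorator otherwise; the sketch below shows that assumed behaviour with a hypothetical helper name:

```python
import os

import torch


def jit_decorator_sketch():
    # Hypothetical stand-in: fall back to a no-op decorator when JIT is disabled.
    if os.environ.get("DISABLE_JIT"):
        return lambda fn: fn
    return torch.jit.script


@jit_decorator_sketch()
def weighted_residual(err: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return (err * weight).sum(-1)


print(weighted_residual(torch.ones(4), torch.full((4,), 0.5)))  # tensor(2.)
```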
@@ -243,7 +244,7 @@ def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
|
||||
|
||||
|
||||
# full method:
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def backward_full_PoseError_jit(
|
||||
grad_out_distance, grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
|
||||
):
|
||||
@@ -570,6 +571,7 @@ class SdfSphereOBB(torch.autograd.Function):
|
||||
sparsity_idx,
|
||||
weight,
|
||||
activation_distance,
|
||||
max_distance,
|
||||
box_accel,
|
||||
box_dims,
|
||||
box_pose,
|
||||
@@ -584,6 +586,8 @@ class SdfSphereOBB(torch.autograd.Function):
|
||||
compute_distance,
|
||||
use_batch_env,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
r = geom_cu.closest_point(
|
||||
query_sphere,
|
||||
@@ -592,6 +596,7 @@ class SdfSphereOBB(torch.autograd.Function):
|
||||
sparsity_idx,
|
||||
weight,
|
||||
activation_distance,
|
||||
max_distance,
|
||||
box_accel,
|
||||
box_dims,
|
||||
box_pose,
|
||||
@@ -605,8 +610,11 @@ class SdfSphereOBB(torch.autograd.Function):
|
||||
transform_back,
|
||||
compute_distance,
|
||||
use_batch_env,
|
||||
sum_collisions,
|
||||
compute_esdf,
|
||||
)
|
||||
# r[1][r[1]!=r[1]] = 0.0
|
||||
ctx.compute_esdf = compute_esdf
|
||||
ctx.return_loss = return_loss
|
||||
ctx.save_for_backward(r[1])
|
||||
return r[0]
|
||||
@@ -615,6 +623,8 @@ class SdfSphereOBB(torch.autograd.Function):
|
||||
def backward(ctx, grad_output):
|
||||
grad_pt = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
# if ctx.compute_esdf:
|
||||
# raise NotImplementedError("Gradients not implemented for compute_esdf=True")
|
||||
(r,) = ctx.saved_tensors
|
||||
if ctx.return_loss:
|
||||
r = r * grad_output.unsqueeze(-1)
|
||||
@@ -640,6 +650,9 @@ class SdfSphereOBB(torch.autograd.Function):
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@@ -670,6 +683,7 @@ class SdfSweptSphereOBB(torch.autograd.Function):
|
||||
compute_distance,
|
||||
use_batch_env,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
r = geom_cu.swept_closest_point(
|
||||
query_sphere,
|
||||
@@ -694,6 +708,7 @@ class SdfSweptSphereOBB(torch.autograd.Function):
|
||||
transform_back,
|
||||
compute_distance,
|
||||
use_batch_env,
|
||||
sum_collisions,
|
||||
)
|
||||
ctx.return_loss = return_loss
|
||||
ctx.save_for_backward(
|
||||
@@ -733,4 +748,200 @@ class SdfSweptSphereOBB(torch.autograd.Function):
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class SdfSphereVoxel(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
query_sphere,
|
||||
out_buffer,
|
||||
grad_out_buffer,
|
||||
sparsity_idx,
|
||||
weight,
|
||||
activation_distance,
|
||||
max_distance,
|
||||
grid_features,
|
||||
grid_params,
|
||||
grid_pose,
|
||||
grid_enable,
|
||||
n_env_grid,
|
||||
env_query_idx,
|
||||
max_nobs,
|
||||
batch_size,
|
||||
horizon,
|
||||
n_spheres,
|
||||
transform_back,
|
||||
compute_distance,
|
||||
use_batch_env,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
|
||||
r = geom_cu.closest_point_voxel(
|
||||
query_sphere,
|
||||
out_buffer,
|
||||
grad_out_buffer,
|
||||
sparsity_idx,
|
||||
weight,
|
||||
activation_distance,
|
||||
max_distance,
|
||||
grid_features,
|
||||
grid_params,
|
||||
grid_pose,
|
||||
grid_enable,
|
||||
n_env_grid,
|
||||
env_query_idx,
|
||||
max_nobs,
|
||||
batch_size,
|
||||
horizon,
|
||||
n_spheres,
|
||||
transform_back,
|
||||
compute_distance,
|
||||
use_batch_env,
|
||||
sum_collisions,
|
||||
compute_esdf,
|
||||
)
|
||||
ctx.compute_esdf = compute_esdf
|
||||
ctx.return_loss = return_loss
|
||||
ctx.save_for_backward(r[1])
|
||||
return r[0]
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_pt = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
# if ctx.compute_esdf:
|
||||
# raise NotImplementedError("Gradients not implemented for compute_esdf=True")
|
||||
(r,) = ctx.saved_tensors
|
||||
if ctx.return_loss:
|
||||
r = r * grad_output.unsqueeze(-1)
|
||||
grad_pt = r
|
||||
return (
|
||||
grad_pt,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
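Note: the new `SdfSphereVoxel` / `SdfSweptSphereVoxel` classes follow the same autograd pattern as the OBB variants above: forward stores a precomputed gradient buffer with `save_for_backward`, and backward rescales it and returns `None` for every non-differentiable argument. A minimal toy example of that pattern (not the collision kernel itself):

```python
import torch


class ScaledIdentity(torch.autograd.Function):
    """Toy illustration of the save-gradient-in-forward pattern used above."""

    @staticmethod
    def forward(ctx, x, scale, return_loss):
        ctx.return_loss = return_loss
        ctx.save_for_backward(torch.full_like(x, scale))  # analytic gradient
        return scale * x

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = None
        if ctx.needs_input_grad[0]:
            (g,) = ctx.saved_tensors
            if ctx.return_loss:
                g = g * grad_output
            grad_x = g
        # One None per non-differentiable input (scale, return_loss).
        return grad_x, None, None


x = torch.ones(3, requires_grad=True)
ScaledIdentity.apply(x, 2.0, True).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```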
|
||||
|
||||
|
||||
class SdfSweptSphereVoxel(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
query_sphere,
|
||||
out_buffer,
|
||||
grad_out_buffer,
|
||||
sparsity_idx,
|
||||
weight,
|
||||
activation_distance,
|
||||
max_distance,
|
||||
speed_dt,
|
||||
grid_features,
|
||||
grid_params,
|
||||
grid_pose,
|
||||
grid_enable,
|
||||
n_env_grid,
|
||||
env_query_idx,
|
||||
max_nobs,
|
||||
batch_size,
|
||||
horizon,
|
||||
n_spheres,
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
transform_back,
|
||||
compute_distance,
|
||||
use_batch_env,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
r = geom_cu.swept_closest_point_voxel(
|
||||
query_sphere,
|
||||
out_buffer,
|
||||
grad_out_buffer,
|
||||
sparsity_idx,
|
||||
weight,
|
||||
activation_distance,
|
||||
max_distance,
|
||||
speed_dt,
|
||||
grid_features,
|
||||
grid_params,
|
||||
grid_pose,
|
||||
grid_enable,
|
||||
n_env_grid,
|
||||
env_query_idx,
|
||||
max_nobs,
|
||||
batch_size,
|
||||
horizon,
|
||||
n_spheres,
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
transform_back,
|
||||
compute_distance,
|
||||
use_batch_env,
|
||||
sum_collisions,
|
||||
)
|
||||
|
||||
ctx.return_loss = return_loss
|
||||
ctx.save_for_backward(
|
||||
r[1],
|
||||
)
|
||||
return r[0]
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_pt = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
(r,) = ctx.saved_tensors
|
||||
if ctx.return_loss:
|
||||
r = r * grad_output.unsqueeze(-1)
|
||||
grad_pt = r
|
||||
return (
|
||||
grad_pt,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -39,7 +39,8 @@ except ImportError:
|
||||
|
||||
class LBFGScu(Function):
|
||||
@staticmethod
|
||||
def _call_cuda(
|
||||
def forward(
|
||||
ctx,
|
||||
step_vec,
|
||||
rho_buffer,
|
||||
y_buffer,
|
||||
@@ -50,6 +51,7 @@ class LBFGScu(Function):
|
||||
grad_0,
|
||||
epsilon=0.1,
|
||||
stable_mode=False,
|
||||
use_shared_buffers=True,
|
||||
):
|
||||
m, b, v_dim, _ = y_buffer.shape
|
||||
|
||||
@@ -67,39 +69,12 @@ class LBFGScu(Function):
|
||||
m,
|
||||
v_dim,
|
||||
stable_mode,
|
||||
use_shared_buffers,
|
||||
)
|
||||
step_v = R[0].view(step_vec.shape)
|
||||
|
||||
return step_v
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
step_vec,
|
||||
rho_buffer,
|
||||
y_buffer,
|
||||
s_buffer,
|
||||
q,
|
||||
grad_q,
|
||||
x_0,
|
||||
grad_0,
|
||||
epsilon=0.1,
|
||||
stable_mode=False,
|
||||
):
|
||||
R = LBFGScu._call_cuda(
|
||||
step_vec,
|
||||
rho_buffer,
|
||||
y_buffer,
|
||||
s_buffer,
|
||||
q,
|
||||
grad_q,
|
||||
x_0,
|
||||
grad_0,
|
||||
epsilon=epsilon,
|
||||
stable_mode=stable_mode,
|
||||
)
|
||||
# ctx.save_for_backward(batch_spheres, robot_spheres, link_mats, link_sphere_map)
|
||||
return R
|
||||
return step_v
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
@@ -109,4 +84,5 @@ class LBFGScu(Function):
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -12,8 +12,11 @@
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
@torch.jit.script
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def project_depth_to_pointcloud(depth_image: torch.Tensor, intrinsics_matrix: torch.Tensor):
|
||||
"""Projects numpy depth image to point cloud.
|
||||
|
||||
@@ -43,7 +46,7 @@ def project_depth_to_pointcloud(depth_image: torch.Tensor, intrinsics_matrix: to
|
||||
return raw_pc
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor):
|
||||
"""Projects numpy depth image to point cloud.
|
||||
|
||||
@@ -54,10 +57,10 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
|
||||
Returns:
|
||||
array of float (h, w, 3)
|
||||
"""
|
||||
fx = intrinsics_matrix[:, 0, 0]
|
||||
fy = intrinsics_matrix[:, 1, 1]
|
||||
cx = intrinsics_matrix[:, 0, 2]
|
||||
cy = intrinsics_matrix[:, 1, 2]
|
||||
fx = intrinsics_matrix[:, 0:1, 0:1]
|
||||
fy = intrinsics_matrix[:, 1:2, 1:2]
|
||||
cx = intrinsics_matrix[:, 0:1, 2:3]
|
||||
cy = intrinsics_matrix[:, 1:2, 2:3]
|
||||
|
||||
input_x = torch.arange(width, dtype=torch.float32, device=intrinsics_matrix.device)
|
||||
input_y = torch.arange(height, dtype=torch.float32, device=intrinsics_matrix.device)
|
||||
@@ -73,7 +76,6 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
|
||||
device=intrinsics_matrix.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
output_x = (input_x - cx) / fx
|
||||
output_y = (input_y - cy) / fy
|
||||
|
||||
@@ -84,7 +86,7 @@ def get_projection_rays(height: int, width: int, intrinsics_matrix: torch.Tensor
|
||||
return rays
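Note: keeping the singleton dimensions in the intrinsics slices (`[:, 0:1, 0:1]` instead of `[:, 0, 0]`) lets `(u - cx) / fx` broadcast across a batch of cameras. A standalone sketch of the same pinhole back-projection (shapes and names are mine):

```python
import torch


def projection_rays(height: int, width: int, K: torch.Tensor) -> torch.Tensor:
    # K: (batch, 3, 3) intrinsics. Returns (batch, height*width, 3) rays with z = 1,
    # following the (u - cx)/fx, (v - cy)/fy convention used above.
    fx, fy = K[:, 0:1, 0:1], K[:, 1:2, 1:2]  # keep dims -> (batch, 1, 1)
    cx, cy = K[:, 0:1, 2:3], K[:, 1:2, 2:3]
    u = torch.arange(width, dtype=K.dtype).view(1, 1, width)
    v = torch.arange(height, dtype=K.dtype).view(1, height, 1)
    x = ((u - cx) / fx).expand(-1, height, -1)  # (batch, height, width)
    y = ((v - cy) / fy).expand(-1, -1, width)
    rays = torch.stack((x, y, torch.ones_like(x)), dim=-1)
    return rays.view(K.shape[0], -1, 3)


K = torch.tensor([[[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]]])
print(projection_rays(480, 640, K).shape)  # torch.Size([1, 307200, 3])
```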
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def project_pointcloud_to_depth(
|
||||
pointcloud: torch.Tensor,
|
||||
output_image: torch.Tensor,
|
||||
@@ -106,7 +108,7 @@ def project_pointcloud_to_depth(
|
||||
return output_image
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def project_depth_using_rays(
|
||||
depth_image: torch.Tensor, rays: torch.Tensor, filter_origin: bool = False
|
||||
):
|
||||
|
||||
@@ -11,8 +11,10 @@
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
# from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# @torch.jit.script
|
||||
|
||||
# @get_torch_jit_decorator()
|
||||
def lookup_distance(pt, dist_matrix_flat, num_voxels):
|
||||
# flatten:
|
||||
ind_pt = (
|
||||
@@ -22,7 +24,7 @@ def lookup_distance(pt, dist_matrix_flat, num_voxels):
|
||||
return dist
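Note: `lookup_distance` above indexes a flattened voxel grid, but the flattening expression is truncated in this hunk. The sketch below shows the usual row-major scheme such code implies; treat the exact ordering as an assumption:

```python
import torch


def lookup_distance_sketch(pt, dist_matrix_flat, num_voxels):
    # pt: integer voxel coordinate (3,), dist_matrix_flat: (nx*ny*nz,) distances,
    # num_voxels: (nx, ny, nz). Row-major (x-major) flattening is assumed here.
    nx, ny, nz = num_voxels
    ind_pt = pt[0] * ny * nz + pt[1] * nz + pt[2]
    return dist_matrix_flat[ind_pt]


grid = torch.arange(2 * 3 * 4, dtype=torch.float32)  # toy 2x3x4 distance field
print(lookup_distance_sketch(torch.tensor([1, 2, 3]), grid, (2, 3, 4)))  # tensor(23.)
```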
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def compute_sdf_gradient(pt, dist_matrix_flat, num_voxels, dist):
|
||||
grad_l = []
|
||||
for i in range(3): # x,y,z
|
||||
|
||||
@@ -30,6 +30,11 @@ def create_collision_checker(config: WorldCollisionConfig):
from curobo.geom.sdf.world_mesh import WorldMeshCollision

return WorldMeshCollision(config)
elif config.checker_type == CollisionCheckerType.VOXEL:
# CuRobo
from curobo.geom.sdf.world_voxel import WorldVoxelCollision

return WorldVoxelCollision(config)
else:
log_error("Not implemented", exc_info=True)
log_error("Unknown Collision Checker type: " + config.checker_type, exc_info=True)
raise NotImplementedError
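Note: the hunk above extends the checker factory with a VOXEL branch and makes the failure message name the offending type. A generic sketch of that dispatch-on-enum pattern (toy classes and names, not curobo's):

```python
from enum import Enum


class CheckerType(Enum):
    PRIMITIVE = "primitive"
    MESH = "mesh"
    VOXEL = "voxel"


def create_checker(checker_type: CheckerType, config: dict):
    # Construct lazily per branch, mirroring the factory above.
    registry = {
        CheckerType.PRIMITIVE: lambda c: ("WorldPrimitiveCollision", c),
        CheckerType.MESH: lambda c: ("WorldMeshCollision", c),
        CheckerType.VOXEL: lambda c: ("WorldVoxelCollision", c),
    }
    if checker_type not in registry:
        raise NotImplementedError(f"Unknown collision checker type: {checker_type}")
    return registry[checker_type](config)


print(create_checker(CheckerType.VOXEL, {"voxel_size": 0.02})[0])  # WorldVoxelCollision
```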
@@ -16,284 +16,6 @@ import warp as wp
|
||||
wp.set_module_options({"fast_math": False})
|
||||
|
||||
|
||||
# create warp kernels:
|
||||
@wp.kernel
|
||||
def get_swept_closest_pt(
|
||||
pt: wp.array(dtype=wp.vec4),
|
||||
distance: wp.array(dtype=wp.float32), # this stores the output cost
|
||||
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
|
||||
sparsity_idx: wp.array(dtype=wp.uint8),
|
||||
weight: wp.array(dtype=wp.float32),
|
||||
activation_distance: wp.array(dtype=wp.float32), # eta threshold
|
||||
speed_dt: wp.array(dtype=wp.float32),
|
||||
mesh: wp.array(dtype=wp.uint64),
|
||||
mesh_pose: wp.array(dtype=wp.float32),
|
||||
mesh_enable: wp.array(dtype=wp.uint8),
|
||||
n_env_mesh: wp.array(dtype=wp.int32),
|
||||
max_dist: wp.float32,
|
||||
write_grad: wp.uint8,
|
||||
batch_size: wp.int32,
|
||||
horizon: wp.int32,
|
||||
nspheres: wp.int32,
|
||||
max_nmesh: wp.int32,
|
||||
sweep_steps: wp.uint8,
|
||||
enable_speed_metric: wp.uint8,
|
||||
):
|
||||
# we launch nspheres kernels
|
||||
# compute gradient here and return
|
||||
# distance is negative outside and positive inside
|
||||
tid = int(0)
|
||||
tid = wp.tid()
|
||||
|
||||
b_idx = int(0)
|
||||
h_idx = int(0)
|
||||
sph_idx = int(0)
|
||||
# read horizon
|
||||
eta = float(0.0) # 5cm buffer
|
||||
dt = float(1.0)
|
||||
|
||||
b_idx = tid / (horizon * nspheres)
|
||||
|
||||
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
|
||||
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
|
||||
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
|
||||
return
|
||||
|
||||
n_mesh = int(0)
|
||||
# $wp.printf("%d, %d, %d, %d \n", tid, b_idx, h_idx, sph_idx)
|
||||
# read sphere
|
||||
sphere_0_distance = float(0.0)
|
||||
sphere_2_distance = float(0.0)
|
||||
|
||||
sphere_0 = wp.vec3(0.0)
|
||||
sphere_2 = wp.vec3(0.0)
|
||||
sphere_int = wp.vec3(0.0)
|
||||
sphere_temp = wp.vec3(0.0)
|
||||
k0 = float(0.0)
|
||||
face_index = int(0)
|
||||
face_u = float(0.0)
|
||||
face_v = float(0.0)
|
||||
sign = float(0.0)
|
||||
dist = float(0.0)
|
||||
|
||||
uint_zero = wp.uint8(0)
|
||||
uint_one = wp.uint8(1)
|
||||
cl_pt = wp.vec3(0.0)
|
||||
local_pt = wp.vec3(0.0)
|
||||
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
|
||||
in_rad = in_sphere[3]
|
||||
if in_rad < 0.0:
|
||||
distance[tid] = 0.0
|
||||
if write_grad == 1 and sparsity_idx[tid] == uint_one:
|
||||
sparsity_idx[tid] = uint_zero
|
||||
closest_pt[tid * 4] = 0.0
|
||||
closest_pt[tid * 4 + 1] = 0.0
|
||||
closest_pt[tid * 4 + 2] = 0.0
|
||||
|
||||
return
|
||||
eta = activation_distance[0]
|
||||
dt = speed_dt[0]
|
||||
in_rad += eta
|
||||
|
||||
max_dist_buffer = float(0.0)
|
||||
max_dist_buffer = max_dist
|
||||
if in_rad > max_dist_buffer:
|
||||
max_dist_buffer += in_rad
|
||||
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
|
||||
|
||||
# read in sphere 0:
|
||||
# read in sphere 0:
|
||||
if h_idx > 0:
|
||||
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx - 1) * nspheres) + sph_idx]
|
||||
sphere_0 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
|
||||
sphere_0_distance = wp.length(sphere_0 - in_pt) / 2.0
|
||||
|
||||
if h_idx < horizon - 1:
|
||||
in_sphere = pt[b_idx * horizon * nspheres + ((h_idx + 1) * nspheres) + sph_idx]
|
||||
sphere_2 += wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
|
||||
sphere_2_distance = wp.length(sphere_2 - in_pt) / 2.0
|
||||
|
||||
# read in sphere 2:
|
||||
closest_distance = float(0.0)
|
||||
closest_point = wp.vec3(0.0)
|
||||
i = int(0)
|
||||
dis_length = float(0.0)
|
||||
jump_distance = float(0.0)
|
||||
mid_distance = float(0.0)
|
||||
n_mesh = n_env_mesh[0]
|
||||
obj_position = wp.vec3()
|
||||
|
||||
while i < n_mesh:
|
||||
if mesh_enable[i] == uint_one:
|
||||
obj_position[0] = mesh_pose[i * 8 + 0]
|
||||
obj_position[1] = mesh_pose[i * 8 + 1]
|
||||
obj_position[2] = mesh_pose[i * 8 + 2]
|
||||
obj_quat = wp.quaternion(
|
||||
mesh_pose[i * 8 + 4],
|
||||
mesh_pose[i * 8 + 5],
|
||||
mesh_pose[i * 8 + 6],
|
||||
mesh_pose[i * 8 + 3],
|
||||
)
|
||||
|
||||
obj_w_pose = wp.transform(obj_position, obj_quat)
|
||||
obj_w_pose_t = wp.transform_inverse(obj_w_pose)
|
||||
# transform point to mesh frame:
|
||||
# mesh_pt = T_inverse @ w_pt
|
||||
local_pt = wp.transform_point(obj_w_pose, in_pt)
|
||||
|
||||
if wp.mesh_query_point(
|
||||
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - local_pt)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - local_pt) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
else:
|
||||
dist = -1.0 * dist
|
||||
else:
|
||||
dist = in_rad
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
|
||||
mid_distance = dist
|
||||
# transform sphere -1
|
||||
if h_idx > 0 and mid_distance < sphere_0_distance:
|
||||
jump_distance = mid_distance
|
||||
j = int(0)
|
||||
sphere_temp = wp.transform_point(obj_w_pose, sphere_0)
|
||||
while j < sweep_steps:
|
||||
k0 = (
|
||||
1.0 - 0.5 * jump_distance / sphere_0_distance
|
||||
) # dist could be greater than sphere_0_distance here?
|
||||
sphere_int = k0 * local_pt + ((1.0 - k0) * sphere_temp)
|
||||
if wp.mesh_query_point(
|
||||
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - sphere_int)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - sphere_int) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
jump_distance += dist
|
||||
else:
|
||||
dist = max(-dist - in_rad, in_rad)
|
||||
|
||||
jump_distance += dist
|
||||
else:
|
||||
jump_distance += max_dist_buffer
|
||||
j += 1
|
||||
if jump_distance >= sphere_0_distance:
|
||||
j = int(sweep_steps)
|
||||
# transform sphere -1
|
||||
if h_idx < horizon - 1 and mid_distance < sphere_2_distance:
|
||||
jump_distance = mid_distance
|
||||
j = int(0)
|
||||
sphere_temp = wp.transform_point(obj_w_pose, sphere_2)
|
||||
while j < sweep_steps:
|
||||
k0 = (
|
||||
1.0 - 0.5 * jump_distance / sphere_2_distance
|
||||
) # dist could be greater than sphere_0_distance here?
|
||||
sphere_int = k0 * local_pt + (1.0 - k0) * sphere_temp
|
||||
|
||||
if wp.mesh_query_point(
|
||||
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - sphere_int)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - sphere_int) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
|
||||
jump_distance += dist
|
||||
else:
|
||||
dist = max(-dist - in_rad, in_rad)
|
||||
|
||||
jump_distance += dist
|
||||
else:
|
||||
jump_distance += max_dist_buffer
|
||||
|
||||
j += 1
|
||||
if jump_distance >= sphere_2_distance:
|
||||
j = int(sweep_steps)
|
||||
|
||||
i += 1
|
||||
# return
|
||||
if closest_distance == 0:
|
||||
if sparsity_idx[tid] == uint_zero:
|
||||
return
|
||||
sparsity_idx[tid] = uint_zero
|
||||
distance[tid] = 0.0
|
||||
if write_grad == 1:
|
||||
closest_pt[tid * 4 + 0] = 0.0
|
||||
closest_pt[tid * 4 + 1] = 0.0
|
||||
closest_pt[tid * 4 + 2] = 0.0
|
||||
return
|
||||
|
||||
if enable_speed_metric == 1 and (h_idx > 0 and h_idx < horizon - 1):
|
||||
# calculate sphere velocity and acceleration:
|
||||
norm_vel_vec = wp.vec3(0.0)
|
||||
sph_acc_vec = wp.vec3(0.0)
|
||||
sph_vel = wp.float(0.0)
|
||||
|
||||
# use central difference
|
||||
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
|
||||
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
|
||||
# norm_vel_vec = -1.0 * norm_vel_vec
|
||||
# sph_acc_vec = -1.0 * sph_acc_vec
|
||||
sph_vel = wp.length(norm_vel_vec)
|
||||
|
||||
norm_vel_vec = norm_vel_vec / sph_vel
|
||||
|
||||
orth_proj = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) - wp.outer(
|
||||
norm_vel_vec, norm_vel_vec
|
||||
)
|
||||
|
||||
curvature_vec = orth_proj * (sph_acc_vec / (sph_vel * sph_vel))
|
||||
|
||||
closest_point = sph_vel * ((orth_proj * closest_point) - closest_distance * curvature_vec)
|
||||
|
||||
closest_distance = sph_vel * closest_distance
|
||||
|
||||
distance[tid] = weight[0] * closest_distance
|
||||
sparsity_idx[tid] = uint_one
|
||||
if write_grad == 1:
|
||||
# compute gradient:
|
||||
if closest_distance > 0.0:
|
||||
closest_distance = weight[0]
|
||||
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
|
||||
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
|
||||
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
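Note: the speed-metric branch above weights the collision gradient by sphere speed and removes the component along the motion direction (a CHOMP-style workspace metric). A small numeric sketch of exactly that computation, using central differences over three consecutive sphere centers (the function name is mine):

```python
import numpy as np


def speed_metric_gradient(p_prev, p_cur, p_next, grad, cost, dt):
    # Central-difference velocity/acceleration of the sphere center, then
    # project the gradient orthogonally to the velocity and subtract the
    # curvature term, mirroring the kernel above.
    vel = (p_next - p_prev) * (0.5 / dt)
    acc = (p_prev + p_next - 2.0 * p_cur) / (dt * dt)
    speed = np.linalg.norm(vel)
    vel_hat = vel / speed
    orth_proj = np.eye(3) - np.outer(vel_hat, vel_hat)
    curvature = orth_proj @ (acc / (speed * speed))
    return speed * (orth_proj @ grad - cost * curvature), speed * cost


g, c = speed_metric_gradient(
    np.array([0.00, 0.0, 0.0]),
    np.array([0.01, 0.0, 0.0]),
    np.array([0.02, 0.0, 0.0]),
    grad=np.array([0.0, 1.0, 0.0]),
    cost=0.2,
    dt=0.02,
)
print(g, c)  # [0.  0.5 0. ] 0.1
```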
|
||||
|
||||
|
||||
@wp.kernel
|
||||
def get_swept_closest_pt_batch_env(
|
||||
pt: wp.array(dtype=wp.vec4),
|
||||
@@ -307,7 +29,7 @@ def get_swept_closest_pt_batch_env(
|
||||
mesh_pose: wp.array(dtype=wp.float32),
|
||||
mesh_enable: wp.array(dtype=wp.uint8),
|
||||
n_env_mesh: wp.array(dtype=wp.int32),
|
||||
max_dist: wp.float32,
|
||||
max_dist: wp.array(dtype=wp.float32),
|
||||
write_grad: wp.uint8,
|
||||
batch_size: wp.int32,
|
||||
horizon: wp.int32,
|
||||
@@ -316,6 +38,7 @@ def get_swept_closest_pt_batch_env(
|
||||
sweep_steps: wp.uint8,
|
||||
enable_speed_metric: wp.uint8,
|
||||
env_query_idx: wp.array(dtype=wp.int32),
|
||||
use_batch_env: wp.uint8,
|
||||
):
|
||||
# we launch nspheres kernels
|
||||
# compute gradient here and return
|
||||
@@ -357,6 +80,7 @@ def get_swept_closest_pt_batch_env(
|
||||
sign = float(0.0)
|
||||
dist = float(0.0)
|
||||
dist_metric = float(0.0)
|
||||
euclidean_distance = float(0.0)
|
||||
cl_pt = wp.vec3(0.0)
|
||||
local_pt = wp.vec3(0.0)
|
||||
in_sphere = pt[b_idx * horizon * nspheres + (h_idx * nspheres) + sph_idx]
|
||||
@@ -374,7 +98,7 @@ def get_swept_closest_pt_batch_env(
|
||||
eta = activation_distance[0]
|
||||
in_rad += eta
|
||||
max_dist_buffer = float(0.0)
|
||||
max_dist_buffer = max_dist
|
||||
max_dist_buffer = max_dist[0]
|
||||
if (in_rad) > max_dist_buffer:
|
||||
max_dist_buffer += in_rad
|
||||
|
||||
@@ -396,7 +120,8 @@ def get_swept_closest_pt_batch_env(
|
||||
dis_length = float(0.0)
|
||||
jump_distance = float(0.0)
|
||||
mid_distance = float(0.0)
|
||||
env_idx = env_query_idx[b_idx]
|
||||
if use_batch_env:
|
||||
env_idx = env_query_idx[b_idx]
|
||||
i = max_nmesh * env_idx
|
||||
n_mesh = i + n_env_mesh[env_idx]
|
||||
obj_position = wp.vec3()
|
||||
@@ -423,26 +148,33 @@ def get_swept_closest_pt_batch_env(
|
||||
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - local_pt)
|
||||
delta = cl_pt - local_pt
|
||||
dis_length = wp.length(delta)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - local_pt) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist == in_rad:
|
||||
cl_pt = sign * (delta) / (dist)
|
||||
else:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
euclidean_distance = dist
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
cl_pt = (1.0 / eta) * dist * cl_pt
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
else:
|
||||
dist = -1.0 * dist
|
||||
euclidean_distance = dist
|
||||
else:
|
||||
dist = max_dist_buffer
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
euclidean_distance = dist
|
||||
dist = max(euclidean_distance - in_rad, in_rad)
|
||||
|
||||
mid_distance = dist
|
||||
mid_distance = euclidean_distance
|
||||
# transform sphere -1
|
||||
if h_idx > 0 and mid_distance < sphere_0_distance:
|
||||
jump_distance = mid_distance
|
||||
@@ -457,24 +189,31 @@ def get_swept_closest_pt_batch_env(
|
||||
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - sphere_int)
|
||||
delta = cl_pt - sphere_int
|
||||
dis_length = wp.length(delta)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - sphere_int) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
if dist == in_rad:
|
||||
cl_pt = sign * (delta) / (dist)
|
||||
else:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
euclidean_distance = dist
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
cl_pt = (1.0 / eta) * dist * cl_pt
|
||||
|
||||
closest_distance += dist_metric
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
|
||||
closest_point += grad_vec
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
jump_distance += dist
|
||||
dist = max(euclidean_distance - in_rad, in_rad)
|
||||
jump_distance += euclidean_distance
|
||||
else:
|
||||
dist = max(-dist - in_rad, in_rad)
|
||||
jump_distance += dist
|
||||
euclidean_distance = dist
|
||||
jump_distance += euclidean_distance
|
||||
else:
|
||||
jump_distance += max_dist_buffer
|
||||
j += 1
|
||||
@@ -495,24 +234,30 @@ def get_swept_closest_pt_batch_env(
|
||||
mesh[i], sphere_int, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - sphere_int)
|
||||
delta = cl_pt - sphere_int
|
||||
dis_length = wp.length(delta)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - sphere_int) / dis_length
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
euclidean_distance = dist
|
||||
if dist == in_rad:
|
||||
cl_pt = sign * (delta) / (dist)
|
||||
else:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
# cl_pt = sign * (delta) / dis_length
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
cl_pt = (1.0 / eta) * dist * cl_pt
|
||||
closest_distance += dist_metric
|
||||
grad_vec = wp.transform_vector(obj_w_pose_t, cl_pt)
|
||||
|
||||
closest_point += grad_vec
|
||||
dist = max(dist - in_rad, in_rad)
|
||||
dist = max(euclidean_distance - in_rad, in_rad)
|
||||
jump_distance += dist
|
||||
|
||||
else:
|
||||
dist = max(-dist - in_rad, in_rad)
|
||||
|
||||
jump_distance += dist
|
||||
else:
|
||||
jump_distance += max_dist_buffer
|
||||
@@ -542,179 +287,54 @@ def get_swept_closest_pt_batch_env(
|
||||
|
||||
# use central difference
|
||||
norm_vel_vec = (0.5 / dt) * (sphere_2 - sphere_0)
|
||||
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
|
||||
# norm_vel_vec = -1.0 * norm_vel_vec
|
||||
# sph_acc_vec = -1.0 * sph_acc_vec
|
||||
sph_vel = wp.length(norm_vel_vec)
|
||||
if sph_vel > 1e-3:
|
||||
sph_acc_vec = (1.0 / (dt * dt)) * (sphere_0 + sphere_2 - 2.0 * in_pt)
|
||||
|
||||
norm_vel_vec = norm_vel_vec / sph_vel
|
||||
norm_vel_vec = norm_vel_vec * (1.0 / sph_vel)
|
||||
|
||||
orth_proj = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) - wp.outer(
|
||||
norm_vel_vec, norm_vel_vec
|
||||
)
|
||||
curvature_vec = sph_acc_vec / (sph_vel * sph_vel)
|
||||
|
||||
curvature_vec = orth_proj * (sph_acc_vec / (sph_vel * sph_vel))
|
||||
orth_proj = wp.mat33(0.0)
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
orth_proj[i, j] = -1.0 * norm_vel_vec[i] * norm_vel_vec[j]
|
||||
|
||||
closest_point = sph_vel * ((orth_proj * closest_point) - closest_distance * curvature_vec)
|
||||
orth_proj[0, 0] = orth_proj[0, 0] + 1.0
|
||||
orth_proj[1, 1] = orth_proj[1, 1] + 1.0
|
||||
orth_proj[2, 2] = orth_proj[2, 2] + 1.0
|
||||
|
||||
closest_distance = sph_vel * closest_distance
|
||||
orth_curv = wp.vec3(
|
||||
0.0, 0.0, 0.0
|
||||
) # closest_distance * (orth_proj @ curvature_vec) #wp.matmul(orth_proj, curvature_vec)
|
||||
orth_pt = wp.vec3(0.0, 0.0, 0.0) # orth_proj @ closest_point
|
||||
for i in range(3):
|
||||
orth_pt[i] = (
|
||||
orth_proj[i, 0] * closest_point[0]
|
||||
+ orth_proj[i, 1] * closest_point[1]
|
||||
+ orth_proj[i, 2] * closest_point[2]
|
||||
)
|
||||
|
||||
orth_curv[i] = closest_distance * (
|
||||
orth_proj[i, 0] * curvature_vec[0]
|
||||
+ orth_proj[i, 1] * curvature_vec[1]
|
||||
+ orth_proj[i, 2] * curvature_vec[2]
|
||||
)
|
||||
|
||||
closest_point = sph_vel * (orth_pt - orth_curv)
|
||||
|
||||
closest_distance = sph_vel * closest_distance
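The scaling applied just above appears to follow the CHOMP-style speed metric for swept obstacle costs (a reconstruction under that assumption, with $\Delta t$ the timestep read from speed_dt):

$$\dot{x} = \frac{x_{t+1} - x_{t-1}}{2\Delta t}, \quad \ddot{x} = \frac{x_{t+1} - 2x_t + x_{t-1}}{\Delta t^{2}}, \quad P = I - \hat{\dot{x}}\hat{\dot{x}}^{\top}, \quad \kappa = \frac{P\,\ddot{x}}{\lVert\dot{x}\rVert^{2}}$$
$$\nabla c \leftarrow \lVert\dot{x}\rVert\left(P\,\nabla c - c\,\kappa\right), \qquad c \leftarrow \lVert\dot{x}\rVert\, c$$

The explicit loops build $P$ as $-\hat{\dot{x}}\hat{\dot{x}}^{\top}$ plus ones on the diagonal and then apply it to both the accumulated gradient (closest_point) and the curvature term, replacing the earlier matrix expression.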
|
||||
|
||||
distance[tid] = weight[0] * closest_distance
|
||||
sparsity_idx[tid] = uint_one
|
||||
if write_grad == 1:
|
||||
# compute gradient:
|
||||
if closest_distance > 0.0:
|
||||
closest_distance = weight[0]
|
||||
closest_distance = weight[0]
|
||||
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
|
||||
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
|
||||
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
|
||||
|
||||
|
||||
@wp.kernel
|
||||
def get_closest_pt(
|
||||
pt: wp.array(dtype=wp.vec4),
|
||||
distance: wp.array(dtype=wp.float32), # this stores the output cost
|
||||
closest_pt: wp.array(dtype=wp.float32), # this stores the gradient
|
||||
sparsity_idx: wp.array(dtype=wp.uint8),
|
||||
weight: wp.array(dtype=wp.float32),
|
||||
activation_distance: wp.array(dtype=wp.float32), # eta threshold
|
||||
mesh: wp.array(dtype=wp.uint64),
|
||||
mesh_pose: wp.array(dtype=wp.float32),
|
||||
mesh_enable: wp.array(dtype=wp.uint8),
|
||||
n_env_mesh: wp.array(dtype=wp.int32),
|
||||
max_dist: wp.float32,
|
||||
write_grad: wp.uint8,
|
||||
batch_size: wp.int32,
|
||||
horizon: wp.int32,
|
||||
nspheres: wp.int32,
|
||||
max_nmesh: wp.int32,
|
||||
):
|
||||
# we launch nspheres kernels
|
||||
# compute gradient here and return
|
||||
# distance is negative outside and positive inside
|
||||
tid = wp.tid()
|
||||
n_mesh = int(0)
|
||||
b_idx = int(0)
|
||||
h_idx = int(0)
|
||||
sph_idx = int(0)
|
||||
|
||||
# env_idx = int(0)
|
||||
|
||||
b_idx = tid / (horizon * nspheres)
|
||||
|
||||
h_idx = (tid - (b_idx * (horizon * nspheres))) / nspheres
|
||||
sph_idx = tid - (b_idx * horizon * nspheres) - (h_idx * nspheres)
|
||||
if b_idx >= batch_size or h_idx >= horizon or sph_idx >= nspheres:
|
||||
return
|
||||
|
||||
face_index = int(0)
|
||||
face_u = float(0.0)
|
||||
face_v = float(0.0)
|
||||
sign = float(0.0)
|
||||
dist = float(0.0)
|
||||
grad_vec = wp.vec3(0.0)
|
||||
eta = float(0.05)
|
||||
dist_metric = float(0.0)
|
||||
|
||||
cl_pt = wp.vec3(0.0)
|
||||
local_pt = wp.vec3(0.0)
|
||||
in_sphere = pt[tid]
|
||||
in_pt = wp.vec3(in_sphere[0], in_sphere[1], in_sphere[2])
|
||||
in_rad = in_sphere[3]
|
||||
uint_zero = wp.uint8(0)
|
||||
uint_one = wp.uint8(1)
|
||||
if in_rad < 0.0:
|
||||
distance[tid] = 0.0
|
||||
if write_grad == 1 and sparsity_idx[tid] == uint_one:
|
||||
sparsity_idx[tid] = uint_zero
|
||||
closest_pt[tid * 4] = 0.0
|
||||
closest_pt[tid * 4 + 1] = 0.0
|
||||
closest_pt[tid * 4 + 2] = 0.0
|
||||
return
|
||||
|
||||
eta = activation_distance[0]
|
||||
in_rad += eta
|
||||
max_dist_buffer = float(0.0)
|
||||
max_dist_buffer = max_dist
|
||||
if in_rad > max_dist_buffer:
|
||||
max_dist_buffer += in_rad
|
||||
|
||||
# TODO: read vec4 and use first 3 for sphere position and last one for radius
|
||||
# in_pt = pt[tid]
|
||||
closest_distance = float(0.0)
|
||||
closest_point = wp.vec3(0.0)
|
||||
i = int(0)
|
||||
dis_length = float(0.0)
|
||||
# read env index:
|
||||
# env_idx = env_query_idx[b_idx]
|
||||
# read number of boxes in current environment:
|
||||
|
||||
# get start index
|
||||
i = int(0)
|
||||
n_mesh = n_env_mesh[0]
|
||||
obj_position = wp.vec3()
|
||||
|
||||
# mesh_idx = wp.uint64(0)
|
||||
while i < n_mesh:
|
||||
if mesh_enable[i] == uint_one:
|
||||
# transform point to mesh frame:
|
||||
# mesh_pt = T_inverse @ w_pt
|
||||
|
||||
# read object pose:
|
||||
obj_position[0] = mesh_pose[i * 8 + 0]
|
||||
obj_position[1] = mesh_pose[i * 8 + 1]
|
||||
obj_position[2] = mesh_pose[i * 8 + 2]
|
||||
obj_quat = wp.quaternion(
|
||||
mesh_pose[i * 8 + 4],
|
||||
mesh_pose[i * 8 + 5],
|
||||
mesh_pose[i * 8 + 6],
|
||||
mesh_pose[i * 8 + 3],
|
||||
)
|
||||
|
||||
obj_w_pose = wp.transform(obj_position, obj_quat)
|
||||
|
||||
local_pt = wp.transform_point(obj_w_pose, in_pt)
|
||||
# mesh_idx = mesh[i]
|
||||
if wp.mesh_query_point(
|
||||
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - local_pt)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - local_pt) / dis_length
|
||||
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
|
||||
i += 1
|
||||
|
||||
if closest_distance == 0:
|
||||
if sparsity_idx[tid] == uint_zero:
|
||||
return
|
||||
sparsity_idx[tid] = uint_zero
|
||||
distance[tid] = 0.0
|
||||
if write_grad == 1:
|
||||
closest_pt[tid * 4 + 0] = 0.0
|
||||
closest_pt[tid * 4 + 1] = 0.0
|
||||
closest_pt[tid * 4 + 2] = 0.0
|
||||
else:
|
||||
distance[tid] = weight[0] * closest_distance
|
||||
sparsity_idx[tid] = uint_one
|
||||
if write_grad == 1:
|
||||
# compute gradient:
|
||||
if closest_distance > 0.0:
|
||||
closest_distance = weight[0]
|
||||
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
|
||||
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
|
||||
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
|
||||
|
||||
|
||||
@wp.kernel
|
||||
def get_closest_pt_batch_env(
|
||||
pt: wp.array(dtype=wp.vec4),
|
||||
@@ -727,13 +347,15 @@ def get_closest_pt_batch_env(
|
||||
mesh_pose: wp.array(dtype=wp.float32),
|
||||
mesh_enable: wp.array(dtype=wp.uint8),
|
||||
n_env_mesh: wp.array(dtype=wp.int32),
|
||||
max_dist: wp.float32,
|
||||
max_dist: wp.array(dtype=wp.float32),
|
||||
write_grad: wp.uint8,
|
||||
batch_size: wp.int32,
|
||||
horizon: wp.int32,
|
||||
nspheres: wp.int32,
|
||||
max_nmesh: wp.int32,
|
||||
env_query_idx: wp.array(dtype=wp.int32),
|
||||
use_batch_env: wp.uint8,
|
||||
compute_esdf: wp.uint8,
|
||||
):
|
||||
# we launch nspheres kernels
|
||||
# compute gradient here and return
|
||||
@@ -779,8 +401,9 @@ def get_closest_pt_batch_env(
|
||||
|
||||
return
|
||||
eta = activation_distance[0]
|
||||
in_rad += eta
|
||||
max_dist_buffer = max_dist
|
||||
if compute_esdf != 1:
|
||||
in_rad += eta
|
||||
max_dist_buffer = max_dist[0]
|
||||
if (in_rad) > max_dist_buffer:
|
||||
max_dist_buffer += in_rad
|
||||
|
||||
@@ -791,7 +414,9 @@ def get_closest_pt_batch_env(
|
||||
dis_length = float(0.0)
|
||||
|
||||
# read env index:
|
||||
env_idx = env_query_idx[b_idx]
|
||||
if use_batch_env:
|
||||
env_idx = env_query_idx[b_idx]
|
||||
|
||||
# read number of boxes in current environment:
|
||||
|
||||
# get start index
|
||||
@@ -799,7 +424,9 @@ def get_closest_pt_batch_env(
|
||||
i = max_nmesh * env_idx
|
||||
n_mesh = i + n_env_mesh[env_idx]
|
||||
obj_position = wp.vec3()
|
||||
|
||||
max_dist_value = -1.0 * max_dist_buffer
|
||||
if compute_esdf == 1:
|
||||
closest_distance = max_dist_value
|
||||
while i < n_mesh:
|
||||
if mesh_enable[i] == uint_one:
|
||||
# transform point to mesh frame:
|
||||
@@ -822,21 +449,39 @@ def get_closest_pt_batch_env(
|
||||
mesh[i], local_pt, max_dist_buffer, sign, face_index, face_u, face_v
|
||||
):
|
||||
cl_pt = wp.mesh_eval_position(mesh[i], face_index, face_u, face_v)
|
||||
dis_length = wp.length(cl_pt - local_pt)
|
||||
dist = (-1.0 * dis_length * sign) + in_rad
|
||||
if dist > 0:
|
||||
cl_pt = sign * (cl_pt - local_pt) / dis_length
|
||||
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
grad_vec = (1.0 / eta) * dist * grad_vec
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
delta = cl_pt - local_pt
|
||||
dis_length = wp.length(delta)
|
||||
dist = -1.0 * dis_length * sign
|
||||
|
||||
if compute_esdf == 1:
|
||||
if dist > max_dist_value:
|
||||
max_dist_value = dist
|
||||
closest_distance = dist
|
||||
if write_grad == 1:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
|
||||
closest_point = grad_vec
|
||||
|
||||
else:
|
||||
dist = dist + in_rad
|
||||
|
||||
if dist > 0:
|
||||
if dist == in_rad:
|
||||
cl_pt = sign * (delta) / (dist)
|
||||
else:
|
||||
cl_pt = sign * (delta) / dis_length
|
||||
if dist > eta:
|
||||
dist_metric = dist - 0.5 * eta
|
||||
elif dist <= eta:
|
||||
dist_metric = (0.5 / eta) * (dist) * dist
|
||||
cl_pt = (1.0 / eta) * dist * cl_pt
|
||||
grad_vec = wp.transform_vector(wp.transform_inverse(obj_w_pose), cl_pt)
|
||||
|
||||
closest_distance += dist_metric
|
||||
closest_point += grad_vec
|
||||
i += 1
|
||||
|
||||
if closest_distance == 0:
|
||||
if closest_distance == 0 and compute_esdf != 1:
|
||||
if sparsity_idx[tid] == uint_zero:
|
||||
return
|
||||
sparsity_idx[tid] = uint_zero
|
||||
@@ -850,8 +495,7 @@ def get_closest_pt_batch_env(
|
||||
sparsity_idx[tid] = uint_one
|
||||
if write_grad == 1:
|
||||
# compute gradient:
|
||||
if closest_distance > 0.0:
|
||||
closest_distance = weight[0]
|
||||
closest_distance = weight[0]
|
||||
closest_pt[tid * 4 + 0] = closest_distance * closest_point[0]
|
||||
closest_pt[tid * 4 + 1] = closest_distance * closest_point[1]
|
||||
closest_pt[tid * 4 + 2] = closest_distance * closest_point[2]
|
||||
@@ -871,62 +515,43 @@ class SdfMeshWarpPy(torch.autograd.Function):
|
||||
mesh_pose_inverse,
|
||||
mesh_enable,
|
||||
n_env_mesh,
|
||||
max_dist=0.05,
|
||||
max_dist,
|
||||
env_query_idx=None,
|
||||
return_loss=False,
|
||||
compute_esdf=False,
|
||||
):
|
||||
b, h, n, _ = query_spheres.shape
|
||||
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
# launch
|
||||
wp.launch(
|
||||
kernel=get_closest_pt,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
max_dist,
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
else:
|
||||
wp.launch(
|
||||
kernel=get_closest_pt_batch_env,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
max_dist,
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
use_batch_env = False
|
||||
env_query_idx = n_env_mesh
|
||||
|
||||
wp.launch(
|
||||
kernel=get_closest_pt_batch_env,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
wp.from_torch(max_dist, dtype=wp.float32),
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
|
||||
use_batch_env,
|
||||
compute_esdf,
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
ctx.return_loss = return_loss
|
||||
ctx.save_for_backward(out_grad)
|
||||
return out_cost
|
||||
@@ -939,7 +564,22 @@ class SdfMeshWarpPy(torch.autograd.Function):
|
||||
grad_sph = r
|
||||
if ctx.return_loss:
|
||||
grad_sph = r * grad_output.unsqueeze(-1)
|
||||
return grad_sph, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
return (
|
||||
grad_sph,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class SweptSdfMeshWarpPy(torch.autograd.Function):
|
||||
@@ -957,69 +597,46 @@ class SweptSdfMeshWarpPy(torch.autograd.Function):
|
||||
mesh_pose_inverse,
|
||||
mesh_enable,
|
||||
n_env_mesh,
|
||||
max_dist,
|
||||
sweep_steps=1,
|
||||
enable_speed_metric=False,
|
||||
max_dist=0.05,
|
||||
env_query_idx=None,
|
||||
return_loss=False,
|
||||
):
|
||||
b, h, n, _ = query_spheres.shape
|
||||
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
wp.launch(
|
||||
kernel=get_swept_closest_pt,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(speed_dt),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
max_dist,
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
else:
|
||||
wp.launch(
|
||||
kernel=get_swept_closest_pt_batch_env,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(speed_dt),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
max_dist,
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
use_batch_env = False
|
||||
env_query_idx = n_env_mesh
|
||||
|
||||
wp.launch(
|
||||
kernel=get_swept_closest_pt_batch_env,
|
||||
dim=b * h * n,
|
||||
inputs=[
|
||||
wp.from_torch(query_spheres.detach().view(-1, 4), dtype=wp.vec4),
|
||||
wp.from_torch(out_cost.view(-1)),
|
||||
wp.from_torch(out_grad.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(sparsity_idx.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(weight),
|
||||
wp.from_torch(activation_distance),
|
||||
wp.from_torch(speed_dt),
|
||||
wp.from_torch(mesh_idx.view(-1), dtype=wp.uint64),
|
||||
wp.from_torch(mesh_pose_inverse.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(mesh_enable.view(-1), dtype=wp.uint8),
|
||||
wp.from_torch(n_env_mesh.view(-1), dtype=wp.int32),
|
||||
wp.from_torch(max_dist, dtype=wp.float32),
|
||||
query_spheres.requires_grad,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
mesh_idx.shape[1],
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
wp.from_torch(env_query_idx.view(-1), dtype=wp.int32),
|
||||
use_batch_env,
|
||||
],
|
||||
stream=wp.stream_from_torch(query_spheres.device),
|
||||
)
|
||||
ctx.return_loss = return_loss
|
||||
ctx.save_for_backward(out_grad)
|
||||
return out_cost
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.curobolib.geom import SdfSphereOBB, SdfSweptSphereOBB
|
||||
from curobo.geom.types import Cuboid, Mesh, Obstacle, WorldConfig, batch_tensor_cube
|
||||
from curobo.geom.types import Cuboid, Mesh, Obstacle, VoxelGrid, WorldConfig, batch_tensor_cube
|
||||
from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.math import Pose
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
@@ -39,10 +39,14 @@ class CollisionBuffer:
|
||||
def initialize_from_shape(cls, shape: torch.Size, tensor_args: TensorDeviceType):
|
||||
batch, horizon, n_spheres, _ = shape
|
||||
distance_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
(batch, horizon, n_spheres),
|
||||
device=tensor_args.device,
|
||||
dtype=tensor_args.collision_distance_dtype,
|
||||
)
|
||||
grad_distance_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres, 4), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
(batch, horizon, n_spheres, 4),
|
||||
device=tensor_args.device,
|
||||
dtype=tensor_args.collision_gradient_dtype,
|
||||
)
|
||||
sparsity_idx = torch.zeros(
|
||||
(batch, horizon, n_spheres),
|
||||
@@ -54,10 +58,14 @@ class CollisionBuffer:
|
||||
def _update_from_shape(self, shape: torch.Size, tensor_args: TensorDeviceType):
|
||||
batch, horizon, n_spheres, _ = shape
|
||||
self.distance_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
(batch, horizon, n_spheres),
|
||||
device=tensor_args.device,
|
||||
dtype=tensor_args.collision_distance_dtype,
|
||||
)
|
||||
self.grad_distance_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres, 4), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
(batch, horizon, n_spheres, 4),
|
||||
device=tensor_args.device,
|
||||
dtype=tensor_args.collision_gradient_dtype,
|
||||
)
|
||||
self.sparsity_index_buffer = torch.zeros(
|
||||
(batch, horizon, n_spheres),
|
||||
@@ -100,6 +108,7 @@ class CollisionQueryBuffer:
|
||||
primitive_collision_buffer: Optional[CollisionBuffer] = None
|
||||
mesh_collision_buffer: Optional[CollisionBuffer] = None
|
||||
blox_collision_buffer: Optional[CollisionBuffer] = None
|
||||
voxel_collision_buffer: Optional[CollisionBuffer] = None
|
||||
shape: Optional[torch.Size] = None
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -110,6 +119,8 @@ class CollisionQueryBuffer:
|
||||
self.shape = self.mesh_collision_buffer.shape
|
||||
elif self.blox_collision_buffer is not None:
|
||||
self.shape = self.blox_collision_buffer.shape
|
||||
elif self.voxel_collision_buffer is not None:
|
||||
self.shape = self.voxel_collision_buffer.shape
|
||||
|
||||
def __mul__(self, scalar: float):
|
||||
if self.primitive_collision_buffer is not None:
|
||||
@@ -118,17 +129,27 @@ class CollisionQueryBuffer:
|
||||
self.mesh_collision_buffer = self.mesh_collision_buffer * scalar
|
||||
if self.blox_collision_buffer is not None:
|
||||
self.blox_collision_buffer = self.blox_collision_buffer * scalar
|
||||
if self.voxel_collision_buffer is not None:
|
||||
self.voxel_collision_buffer = self.voxel_collision_buffer * scalar
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
prim_buffer = mesh_buffer = blox_buffer = None
|
||||
prim_buffer = mesh_buffer = blox_buffer = voxel_buffer = None
|
||||
if self.primitive_collision_buffer is not None:
|
||||
prim_buffer = self.primitive_collision_buffer.clone()
|
||||
if self.mesh_collision_buffer is not None:
|
||||
mesh_buffer = self.mesh_collision_buffer.clone()
|
||||
if self.blox_collision_buffer is not None:
|
||||
blox_buffer = self.blox_collision_buffer.clone()
|
||||
return CollisionQueryBuffer(prim_buffer, mesh_buffer, blox_buffer, self.shape)
|
||||
if self.voxel_collision_buffer is not None:
|
||||
voxel_buffer = self.voxel_collision_buffer.clone()
|
||||
return CollisionQueryBuffer(
|
||||
prim_buffer,
|
||||
mesh_buffer,
|
||||
blox_buffer,
|
||||
voxel_collision_buffer=voxel_buffer,
|
||||
shape=self.shape,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def initialize_from_shape(
|
||||
@@ -137,14 +158,18 @@ class CollisionQueryBuffer:
|
||||
tensor_args: TensorDeviceType,
|
||||
collision_types: Dict[str, bool],
|
||||
):
|
||||
primitive_buffer = mesh_buffer = blox_buffer = None
|
||||
primitive_buffer = mesh_buffer = blox_buffer = voxel_buffer = None
|
||||
if "primitive" in collision_types and collision_types["primitive"]:
|
||||
primitive_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
if "mesh" in collision_types and collision_types["mesh"]:
|
||||
mesh_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
if "blox" in collision_types and collision_types["blox"]:
|
||||
blox_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
return CollisionQueryBuffer(primitive_buffer, mesh_buffer, blox_buffer)
|
||||
if "voxel" in collision_types and collision_types["voxel"]:
|
||||
voxel_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
return CollisionQueryBuffer(
|
||||
primitive_buffer, mesh_buffer, blox_buffer, voxel_collision_buffer=voxel_buffer
|
||||
)
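A minimal usage sketch of the classmethod above (the shape values are illustrative); only the buffers whose collision type is enabled get allocated, which now includes the voxel buffer:

import torch
from curobo.geom.sdf.world import CollisionQueryBuffer
from curobo.types.base import TensorDeviceType

tensor_args = TensorDeviceType()
shape = torch.Size([1, 1, 64, 4])  # batch, horizon, n_spheres, 4
buffers = CollisionQueryBuffer.initialize_from_shape(
    shape,
    tensor_args,
    collision_types={"mesh": True, "voxel": True},  # primitive and blox buffers stay None
)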
|
||||
|
||||
def create_from_shape(
|
||||
self,
|
||||
@@ -160,8 +185,9 @@ class CollisionQueryBuffer:
|
||||
self.mesh_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
if "blox" in collision_types and collision_types["blox"]:
|
||||
self.blox_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
if "voxel" in collision_types and collision_types["voxel"]:
|
||||
self.voxel_collision_buffer = CollisionBuffer.initialize_from_shape(shape, tensor_args)
|
||||
self.shape = shape
|
||||
# return self
|
||||
|
||||
def update_buffer_shape(
|
||||
self,
|
||||
@@ -169,12 +195,10 @@ class CollisionQueryBuffer:
|
||||
tensor_args: TensorDeviceType,
|
||||
collision_types: Optional[Dict[str, bool]],
|
||||
):
|
||||
# print(shape, self.shape)
|
||||
# update buffers:
|
||||
assert len(shape) == 4 # shape is: batch, horizon, n_spheres, 4
|
||||
if self.shape is None: # buffers not initialized:
|
||||
self.create_from_shape(shape, tensor_args, collision_types)
|
||||
# print("Creating new memory", self.shape)
|
||||
else:
|
||||
# update buffers if shape doesn't match:
|
||||
# TODO: allow for dynamic change of collision_types
|
||||
@@ -185,6 +209,8 @@ class CollisionQueryBuffer:
|
||||
self.mesh_collision_buffer.update_buffer_shape(shape, tensor_args)
|
||||
if self.blox_collision_buffer is not None:
|
||||
self.blox_collision_buffer.update_buffer_shape(shape, tensor_args)
|
||||
if self.voxel_collision_buffer is not None:
|
||||
self.voxel_collision_buffer.update_buffer_shape(shape, tensor_args)
|
||||
self.shape = shape
|
||||
|
||||
def get_gradient_buffer(
|
||||
@@ -208,6 +234,12 @@ class CollisionQueryBuffer:
|
||||
current_buffer = blox_buffer.clone()
|
||||
else:
|
||||
current_buffer += blox_buffer
|
||||
if self.voxel_collision_buffer is not None:
|
||||
voxel_buffer = self.voxel_collision_buffer.grad_distance_buffer
|
||||
if current_buffer is None:
|
||||
current_buffer = voxel_buffer.clone()
|
||||
else:
|
||||
current_buffer += voxel_buffer
|
||||
|
||||
return current_buffer
|
||||
|
||||
@@ -221,6 +253,7 @@ class CollisionCheckerType(Enum):
|
||||
PRIMITIVE = "PRIMITIVE"
|
||||
BLOX = "BLOX"
|
||||
MESH = "MESH"
|
||||
VOXEL = "VOXEL"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -230,11 +263,13 @@ class WorldCollisionConfig:
|
||||
cache: Optional[Dict[Obstacle, int]] = None
|
||||
n_envs: int = 1
|
||||
checker_type: CollisionCheckerType = CollisionCheckerType.PRIMITIVE
|
||||
max_distance: float = 0.01
|
||||
max_distance: Union[torch.Tensor, float] = 0.01
|
||||
|
||||
def __post_init__(self):
|
||||
if self.world_model is not None and isinstance(self.world_model, list):
|
||||
self.n_envs = len(self.world_model)
|
||||
if isinstance(self.max_distance, float):
|
||||
self.max_distance = self.tensor_args.to_device([self.max_distance])
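A small sketch of what this __post_init__ change does: a plain float max_distance is promoted to a one-element tensor on the configured device, so the collision kernels can consume it as a tensor argument (other config fields are left at their defaults here):

from curobo.geom.sdf.world import WorldCollisionConfig
from curobo.types.base import TensorDeviceType

cfg = WorldCollisionConfig(tensor_args=TensorDeviceType(), max_distance=0.05)
# cfg.max_distance is now a torch.Tensor of shape [1] on tensor_args.device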
|
||||
|
||||
@staticmethod
|
||||
def load_from_dict(
|
||||
@@ -261,6 +296,8 @@ class WorldCollision(WorldCollisionConfig):
|
||||
if config is not None:
|
||||
WorldCollisionConfig.__init__(self, **vars(config))
|
||||
self.collision_types = {} # Use this dictionary to store collision types
|
||||
self._cache_voxelization = None
|
||||
self._cache_voxelization_collision_buffer = None
|
||||
|
||||
def load_collision_model(self, world_model: WorldConfig):
|
||||
raise NotImplementedError
|
||||
@@ -273,6 +310,8 @@ class WorldCollision(WorldCollisionConfig):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
"""
|
||||
Computes the signed distance via analytic function
|
||||
@@ -310,6 +349,7 @@ class WorldCollision(WorldCollisionConfig):
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -338,6 +378,118 @@ class WorldCollision(WorldCollisionConfig):
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_voxels_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
|
||||
voxel_size: float = 0.02,
|
||||
) -> Union[List[Cuboid], torch.Tensor]:
|
||||
new_grid = self.get_occupancy_in_bounding_box(cuboid, voxel_size)
|
||||
occupied = new_grid.get_occupied_voxels(0.0)
|
||||
return occupied
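A short usage sketch of this helper (world_coll stands for any initialized WorldCollision subclass with obstacles loaded; region and voxel size are illustrative). It voxelizes the box, queries occupancy, and returns the voxels whose feature is above the 0.0 threshold used above, presumably as rows of [x, y, z, r]:

from curobo.geom.types import Cuboid

occupied_voxels = world_coll.get_voxels_in_bounding_box(
    Cuboid(name="region", pose=[0.0, 0.0, 0.0, 1, 0, 0, 0], dims=[0.5, 0.5, 0.5]),
    voxel_size=0.05,
)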
|
||||
|
||||
def clear_voxelization_cache(self):
|
||||
self._cache_voxelization = None
|
||||
|
||||
def update_cache_voxelization(self, new_grid: VoxelGrid):
|
||||
if (
|
||||
self._cache_voxelization is None
|
||||
or self._cache_voxelization.voxel_size != new_grid.voxel_size
|
||||
or self._cache_voxelization.dims != new_grid.dims
|
||||
):
|
||||
self._cache_voxelization = new_grid
|
||||
self._cache_voxelization.xyzr_tensor = self._cache_voxelization.create_xyzr_tensor(
|
||||
transform_to_origin=True, tensor_args=self.tensor_args
|
||||
)
|
||||
self._cache_voxelization_collision_buffer = CollisionQueryBuffer()
|
||||
xyzr = self._cache_voxelization.xyzr_tensor.view(-1, 1, 1, 4)
|
||||
|
||||
self._cache_voxelization_collision_buffer.update_buffer_shape(
|
||||
xyzr.shape,
|
||||
self.tensor_args,
|
||||
self.collision_types,
|
||||
)
|
||||
|
||||
def get_occupancy_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
|
||||
voxel_size: float = 0.02,
|
||||
) -> VoxelGrid:
|
||||
new_grid = VoxelGrid(
|
||||
name=cuboid.name, dims=cuboid.dims, pose=cuboid.pose, voxel_size=voxel_size
|
||||
)
|
||||
|
||||
self.update_cache_voxelization(new_grid)
|
||||
|
||||
xyzr = self._cache_voxelization.xyzr_tensor
|
||||
|
||||
xyzr = xyzr.view(-1, 1, 1, 4)
|
||||
|
||||
weight = self.tensor_args.to_device([1.0])
|
||||
act_distance = self.tensor_args.to_device([0.0])
|
||||
|
||||
d_sph = self.get_sphere_collision(
|
||||
xyzr,
|
||||
self._cache_voxelization_collision_buffer,
|
||||
weight,
|
||||
act_distance,
|
||||
)
|
||||
d_sph = d_sph.reshape(-1)
|
||||
new_grid.xyzr_tensor = self._cache_voxelization.xyzr_tensor
|
||||
new_grid.feature_tensor = d_sph
|
||||
return new_grid
|
||||
|
||||
def get_esdf_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
|
||||
voxel_size: float = 0.02,
|
||||
dtype=torch.float32,
|
||||
) -> VoxelGrid:
|
||||
new_grid = VoxelGrid(
|
||||
name=cuboid.name,
|
||||
dims=cuboid.dims,
|
||||
pose=cuboid.pose,
|
||||
voxel_size=voxel_size,
|
||||
feature_dtype=dtype,
|
||||
)
|
||||
|
||||
self.update_cache_voxelization(new_grid)
|
||||
|
||||
xyzr = self._cache_voxelization.xyzr_tensor
|
||||
voxel_shape = xyzr.shape
|
||||
xyzr = xyzr.view(-1, 1, 1, 4)
|
||||
|
||||
weight = self.tensor_args.to_device([1.0])
|
||||
|
||||
d_sph = self.get_sphere_distance(
|
||||
xyzr,
|
||||
self._cache_voxelization_collision_buffer,
|
||||
weight,
|
||||
self.max_distance,
|
||||
sum_collisions=False,
|
||||
compute_esdf=True,
|
||||
)
|
||||
|
||||
d_sph = d_sph.reshape(-1)
|
||||
voxel_grid = self._cache_voxelization
|
||||
voxel_grid.feature_tensor = d_sph
|
||||
|
||||
return voxel_grid
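A usage sketch of the new ESDF query (world_coll as above; the query region and resolution are illustrative). The returned VoxelGrid carries one signed distance per voxel center, computed through get_sphere_distance with compute_esdf=True:

from curobo.geom.types import Cuboid

esdf_grid = world_coll.get_esdf_in_bounding_box(
    Cuboid(name="query_region", pose=[0.0, 0.0, 0.0, 1, 0, 0, 0], dims=[1.0, 1.0, 1.0]),
    voxel_size=0.05,
)
signed_distance = esdf_grid.feature_tensor  # flat tensor, one value per voxel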
|
||||
|
||||
def get_mesh_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
|
||||
voxel_size: float = 0.02,
|
||||
) -> Mesh:
|
||||
voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size)
|
||||
# voxels = voxels.cpu().numpy()
|
||||
# cuboids = [Cuboid(name="c_"+str(x), pose=[voxels[x,0],voxels[x,1],voxels[x,2], 1,0,0,0], dims=[voxel_size, voxel_size, voxel_size]) for x in range(voxels.shape[0])]
|
||||
# mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0]
|
||||
mesh = Mesh.from_pointcloud(
|
||||
voxels[:, :3].detach().cpu().numpy(),
|
||||
pitch=voxel_size * 1.1,
|
||||
)
|
||||
return mesh
|
||||
|
||||
|
||||
class WorldPrimitiveCollision(WorldCollision):
|
||||
"""World Oriented Bounding Box representation object
|
||||
@@ -354,6 +506,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
self._env_n_obbs = None
|
||||
self._env_obbs_names = None
|
||||
self._init_cache()
|
||||
|
||||
if self.world_model is not None:
|
||||
if isinstance(self.world_model, list):
|
||||
self.load_batch_collision_model(self.world_model)
|
||||
@@ -656,6 +809,8 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
|
||||
raise ValueError("Primitive Collision has no obstacles")
|
||||
@@ -673,6 +828,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
collision_query_buffer.primitive_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self._cube_tensor_list[0],
|
||||
self._cube_tensor_list[0],
|
||||
self._cube_tensor_list[1],
|
||||
@@ -687,6 +843,8 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
sum_collisions,
|
||||
compute_esdf,
|
||||
)
|
||||
|
||||
return dist
|
||||
@@ -699,6 +857,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
**kwargs,
|
||||
):
|
||||
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
|
||||
raise ValueError("Primitive Collision has no obstacles")
|
||||
@@ -717,6 +876,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
collision_query_buffer.primitive_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self._cube_tensor_list[0],
|
||||
self._cube_tensor_list[0],
|
||||
self._cube_tensor_list[1],
|
||||
@@ -728,7 +888,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
h,
|
||||
n,
|
||||
query_sphere.requires_grad,
|
||||
False,
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
)
|
||||
@@ -745,6 +905,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
"""
|
||||
Computes the signed distance via analytic function
|
||||
@@ -784,6 +945,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
sum_collisions,
|
||||
)
|
||||
|
||||
return dist
|
||||
@@ -836,7 +998,7 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
query_sphere.requires_grad,
|
||||
False,
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
)
|
||||
@@ -845,70 +1007,5 @@ class WorldPrimitiveCollision(WorldCollision):
|
||||
|
||||
def clear_cache(self):
|
||||
if self._cube_tensor_list is not None:
|
||||
self._cube_tensor_list[2][:] = 1
|
||||
self._cube_tensor_list[2][:] = 0
|
||||
self._env_n_obbs[:] = 0
|
||||
|
||||
def get_voxels_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid,
|
||||
voxel_size: float = 0.02,
|
||||
) -> Union[List[Cuboid], torch.Tensor]:
|
||||
bounds = cuboid.dims
|
||||
low = [-bounds[0], -bounds[1], -bounds[2]]
|
||||
high = [bounds[0], bounds[1], bounds[2]]
|
||||
trange = [h - l for l, h in zip(low, high)]
|
||||
|
||||
x = torch.linspace(
|
||||
-bounds[0], bounds[0], int(trange[0] // voxel_size) + 1, device=self.tensor_args.device
|
||||
)
|
||||
y = torch.linspace(
|
||||
-bounds[1], bounds[1], int(trange[1] // voxel_size) + 1, device=self.tensor_args.device
|
||||
)
|
||||
z = torch.linspace(
|
||||
-bounds[2], bounds[2], int(trange[2] // voxel_size) + 1, device=self.tensor_args.device
|
||||
)
|
||||
w, l, h = x.shape[0], y.shape[0], z.shape[0]
|
||||
xyz = (
|
||||
torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
|
||||
)
|
||||
|
||||
pose = Pose.from_list(cuboid.pose, tensor_args=self.tensor_args)
|
||||
xyz = pose.transform_points(xyz.contiguous())
|
||||
r = torch.zeros_like(xyz[:, 0:1]) + voxel_size
|
||||
xyzr = torch.cat([xyz, r], dim=1)
|
||||
xyzr = xyzr.reshape(-1, 1, 1, 4)
|
||||
collision_buffer = CollisionQueryBuffer()
|
||||
collision_buffer.update_buffer_shape(
|
||||
xyzr.shape,
|
||||
self.tensor_args,
|
||||
self.collision_types,
|
||||
)
|
||||
weight = self.tensor_args.to_device([1.0])
|
||||
act_distance = self.tensor_args.to_device([0.0])
|
||||
|
||||
d_sph = self.get_sphere_collision(
|
||||
xyzr,
|
||||
collision_buffer,
|
||||
weight,
|
||||
act_distance,
|
||||
)
|
||||
d_sph = d_sph.reshape(-1)
|
||||
xyzr = xyzr.reshape(-1, 4)
|
||||
# get occupied voxels:
|
||||
occupied = xyzr[d_sph > 0.0]
|
||||
return occupied
|
||||
|
||||
def get_mesh_in_bounding_box(
|
||||
self,
|
||||
cuboid: Cuboid,
|
||||
voxel_size: float = 0.02,
|
||||
) -> Mesh:
|
||||
voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size)
|
||||
# voxels = voxels.cpu().numpy()
|
||||
# cuboids = [Cuboid(name="c_"+str(x), pose=[voxels[x,0],voxels[x,1],voxels[x,2], 1,0,0,0], dims=[voxel_size, voxel_size, voxel_size]) for x in range(voxels.shape[0])]
|
||||
# mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0]
|
||||
mesh = Mesh.from_pointcloud(
|
||||
voxels[:, :3].detach().cpu().numpy(),
|
||||
pitch=voxel_size * 1.1,
|
||||
)
|
||||
return mesh
|
||||
|
||||
@@ -176,6 +176,8 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
if "blox" not in self.collision_types or not self.collision_types["blox"]:
|
||||
return super().get_sphere_distance(
|
||||
@@ -185,6 +187,8 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
activation_distance,
|
||||
env_query_idx,
|
||||
return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
|
||||
d = self._get_blox_sdf(
|
||||
@@ -205,8 +209,13 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
activation_distance,
|
||||
env_query_idx,
|
||||
return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
d = d + d_base
|
||||
if compute_esdf:
|
||||
d = torch.maximum(d, d_base)
|
||||
else:
|
||||
d = d + d_base
|
||||
|
||||
return d
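A standalone illustration of the combination rule above (plain torch, made-up values). With this code's sign convention (positive inside or near obstacles), taking the elementwise maximum of the nvblox distance and the base-class distance implements the union of the two obstacle sets for ESDF queries, while the default path keeps summing the penetration costs:

import torch

d_blox = torch.tensor([0.30, 0.02, -0.10])
d_base = torch.tensor([0.10, 0.10, -0.05])
esdf = torch.maximum(d_blox, d_base)  # ESDF: deepest penetration wins per query
cost = d_blox + d_base                # default: summed collision cost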
|
||||
|
||||
@@ -262,6 +271,7 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
if "blox" not in self.collision_types or not self.collision_types["blox"]:
|
||||
return super().get_swept_sphere_distance(
|
||||
@@ -274,6 +284,7 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
enable_speed_metric,
|
||||
env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
|
||||
d = self._get_blox_swept_sdf(
|
||||
@@ -301,6 +312,7 @@ class WorldBloxCollision(WorldMeshCollision):
|
||||
enable_speed_metric,
|
||||
env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
d = d + d_base
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
self._mesh_tensor_list[0][env_idx, :max_nmesh] = w_mid
|
||||
self._mesh_tensor_list[1][env_idx, :max_nmesh, :7] = w_inv_pose
|
||||
self._mesh_tensor_list[2][env_idx, :max_nmesh] = 1
|
||||
self._mesh_tensor_list[2][env_idx, max_nmesh:] = 0
|
||||
|
||||
self._env_mesh_names[env_idx][:max_nmesh] = name_list
|
||||
self._env_n_mesh[env_idx] = max_nmesh
|
||||
@@ -355,6 +356,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx=None,
|
||||
return_loss=False,
|
||||
compute_esdf=False,
|
||||
):
|
||||
d = SdfMeshWarpPy.apply(
|
||||
query_spheres,
|
||||
@@ -370,6 +372,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
self.max_distance,
|
||||
env_query_idx,
|
||||
return_loss,
|
||||
compute_esdf,
|
||||
)
|
||||
return d
|
||||
|
||||
@@ -397,9 +400,9 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
self._mesh_tensor_list[1],
|
||||
self._mesh_tensor_list[2],
|
||||
self._env_n_mesh,
|
||||
self.max_distance,
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
self.max_distance,
|
||||
env_query_idx,
|
||||
return_loss,
|
||||
)
|
||||
@@ -413,6 +416,8 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
# TODO: if no mesh object exist, call primitive
|
||||
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
|
||||
@@ -423,6 +428,8 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
|
||||
d = self._get_sdf(
|
||||
@@ -432,6 +439,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
|
||||
if "primitive" not in self.collision_types or not self.collision_types["primitive"]:
|
||||
@@ -443,8 +451,13 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
d_val = d.view(d_prim.shape) + d_prim
|
||||
if compute_esdf:
|
||||
d_val = torch.maximum(d.view(d_prim.shape), d_prim)
|
||||
else:
|
||||
d_val = d.view(d_prim.shape) + d_prim
|
||||
return d_val
|
||||
|
||||
def get_sphere_collision(
|
||||
@@ -455,6 +468,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx=None,
|
||||
return_loss=False,
|
||||
**kwargs,
|
||||
):
|
||||
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
|
||||
return super().get_sphere_collision(
|
||||
@@ -501,6 +515,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss: bool = False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
# log_warn("Swept: Mesh + Primitive Collision Checking is experimental")
|
||||
if "mesh" not in self.collision_types or not self.collision_types["mesh"]:
|
||||
@@ -514,6 +529,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
|
||||
d = self._get_swept_sdf(
|
||||
@@ -540,6 +556,7 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
d_val = d.view(d_prim.shape) + d_prim
|
||||
|
||||
@@ -602,4 +619,11 @@ class WorldMeshCollision(WorldPrimitiveCollision):
|
||||
self._wp_mesh_cache = {}
|
||||
if self._mesh_tensor_list is not None:
|
||||
self._mesh_tensor_list[2][:] = 0
|
||||
if self._env_n_mesh is not None:
|
||||
self._env_n_mesh[:] = 0
|
||||
if self._env_mesh_names is not None:
|
||||
self._env_mesh_names = [
|
||||
[None for _ in range(self.cache["mesh"])] for _ in range(self.n_envs)
|
||||
]
|
||||
|
||||
super().clear_cache()
|
||||
|
||||
src/curobo/geom/sdf/world_voxel.py (new file, 699 lines)
@@ -0,0 +1,699 @@
|
||||
#
|
||||
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
#
|
||||
# Standard Library
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Third Party
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.curobolib.geom import SdfSphereVoxel, SdfSweptSphereVoxel
|
||||
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollisionConfig
|
||||
from curobo.geom.sdf.world_mesh import WorldMeshCollision
|
||||
from curobo.geom.types import VoxelGrid, WorldConfig
|
||||
from curobo.types.math import Pose
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
|
||||
|
||||
class WorldVoxelCollision(WorldMeshCollision):
|
||||
"""Voxel grid representation of World, with each voxel containing Euclidean Signed Distance."""
|
||||
|
||||
def __init__(self, config: WorldCollisionConfig):
|
||||
self._env_n_voxels = None
|
||||
self._voxel_tensor_list = None
|
||||
self._env_voxel_names = None
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
def _init_cache(self):
|
||||
if (
|
||||
self.cache is not None
|
||||
and "voxel" in self.cache
|
||||
and self.cache["voxel"] not in [None, 0]
|
||||
):
|
||||
self._create_voxel_cache(self.cache["voxel"])
|
||||
|
||||
def _create_voxel_cache(self, voxel_cache: Dict[str, Any]):
|
||||
n_layers = voxel_cache["layers"]
|
||||
dims = voxel_cache["dims"]
|
||||
voxel_size = voxel_cache["voxel_size"]
|
||||
feature_dtype = voxel_cache["feature_dtype"]
|
||||
n_voxels = int(
|
||||
math.floor(dims[0] / voxel_size)
|
||||
* math.floor(dims[1] / voxel_size)
|
||||
* math.floor(dims[2] / voxel_size)
|
||||
)
|
||||
|
||||
voxel_params = torch.zeros(
|
||||
(self.n_envs, n_layers, 4),
|
||||
dtype=self.tensor_args.dtype,
|
||||
device=self.tensor_args.device,
|
||||
)
|
||||
voxel_pose = torch.zeros(
|
||||
(self.n_envs, n_layers, 8),
|
||||
dtype=self.tensor_args.dtype,
|
||||
device=self.tensor_args.device,
|
||||
)
|
||||
voxel_pose[..., 3] = 1.0
|
||||
voxel_enable = torch.zeros(
|
||||
(self.n_envs, n_layers), dtype=torch.uint8, device=self.tensor_args.device
|
||||
)
|
||||
self._env_n_voxels = torch.zeros(
|
||||
(self.n_envs), device=self.tensor_args.device, dtype=torch.int32
|
||||
)
|
||||
voxel_features = torch.zeros(
|
||||
(self.n_envs, n_layers, n_voxels, 1),
|
||||
device=self.tensor_args.device,
|
||||
dtype=feature_dtype,
|
||||
)
|
||||
|
||||
self._voxel_tensor_list = [voxel_params, voxel_pose, voxel_enable, voxel_features]
|
||||
self.collision_types["voxel"] = True
|
||||
self._env_voxel_names = [[None for _ in range(n_layers)] for _ in range(self.n_envs)]
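A sketch of the cache entry this method consumes (the keys match the reads above; the values are illustrative). Supplying it under the "voxel" key of WorldCollisionConfig.cache pre-allocates the parameter, pose, enable, and feature tensors before any world is loaded:

import torch

voxel_cache = {
    "layers": 1,                     # number of voxel grids per environment
    "dims": [1.0, 1.0, 1.0],         # grid extents in meters
    "voxel_size": 0.02,              # edge length of one voxel
    "feature_dtype": torch.float32,  # dtype of the per-voxel feature (signed distance)
}
# e.g. WorldCollisionConfig(..., cache={"voxel": voxel_cache})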
|
||||
|
||||
def load_collision_model(
|
||||
self, world_model: WorldConfig, env_idx=0, fix_cache_reference: bool = False
|
||||
):
|
||||
self._load_collision_model_in_cache(
|
||||
world_model, env_idx, fix_cache_reference=fix_cache_reference
|
||||
)
|
||||
return super().load_collision_model(
|
||||
world_model, env_idx=env_idx, fix_cache_reference=fix_cache_reference
|
||||
)
|
||||
|
||||
def _load_collision_model_in_cache(
|
||||
self, world_config: WorldConfig, env_idx: int = 0, fix_cache_reference: bool = False
|
||||
):
|
||||
"""TODO:
|
||||
|
||||
_extended_summary_
|
||||
|
||||
Args:
|
||||
world_config: _description_
|
||||
env_idx: _description_
|
||||
fix_cache_reference: _description_
|
||||
"""
|
||||
voxel_objs = world_config.voxel
|
||||
max_obs = len(voxel_objs)
|
||||
self.world_model = world_config
|
||||
if max_obs < 1:
|
||||
log_info("No Voxel objs")
|
||||
return
|
||||
if self._voxel_tensor_list is None or self._voxel_tensor_list[0].shape[1] < max_obs:
|
||||
if not fix_cache_reference:
|
||||
log_info("Creating Voxel cache" + str(max_obs))
|
||||
self._create_voxel_cache(
|
||||
{
|
||||
"layers": max_obs,
|
||||
"dims": voxel_objs[0].dims,
|
||||
"voxel_size": voxel_objs[0].voxel_size,
|
||||
"feature_dtype": voxel_objs[0].feature_dtype,
|
||||
}
|
||||
)
|
||||
else:
|
||||
log_error("number of OBB is larger than collision cache, create larger cache.")
|
||||
|
||||
# load as a batch:
|
||||
pose_batch = [c.pose for c in voxel_objs]
|
||||
dims_batch = [c.dims for c in voxel_objs]
|
||||
names_batch = [c.name for c in voxel_objs]
|
||||
size_batch = [c.voxel_size for c in voxel_objs]
|
||||
voxel_batch = self._batch_tensor_voxel(pose_batch, dims_batch, size_batch)
|
||||
self._voxel_tensor_list[0][env_idx, :max_obs, :] = voxel_batch[0]
|
||||
self._voxel_tensor_list[1][env_idx, :max_obs, :7] = voxel_batch[1]
|
||||
|
||||
self._voxel_tensor_list[2][env_idx, :max_obs] = 1 # enabling obstacle
|
||||
|
||||
self._voxel_tensor_list[2][env_idx, max_obs:] = 0 # disabling obstacle
|
||||
|
||||
# copy voxel grid features:
|
||||
|
||||
self._env_n_voxels[env_idx] = max_obs
|
||||
self._env_voxel_names[env_idx][:max_obs] = names_batch
|
||||
self.collision_types["voxel"] = True
|
||||
|
||||
def _batch_tensor_voxel(
|
||||
self, pose: List[List[float]], dims: List[float], voxel_size: List[float]
|
||||
):
|
||||
w_T_b = Pose.from_batch_list(pose, tensor_args=self.tensor_args)
|
||||
b_T_w = w_T_b.inverse()
|
||||
dims_t = torch.as_tensor(
|
||||
np.array(dims), device=self.tensor_args.device, dtype=self.tensor_args.dtype
|
||||
)
|
||||
size_t = torch.as_tensor(
|
||||
np.array(voxel_size), device=self.tensor_args.device, dtype=self.tensor_args.dtype
|
||||
).unsqueeze(-1)
|
||||
params_t = torch.cat([dims_t, size_t], dim=-1)
|
||||
|
||||
voxel_list = [params_t, b_T_w.get_pose_vector()]
|
||||
return voxel_list
|
||||
|
||||
def load_batch_collision_model(self, world_config_list: List[WorldConfig]):
|
||||
"""Load voxel grid for batched environments
|
||||
|
||||
_extended_summary_
|
||||
|
||||
Args:
|
||||
world_config_list: _description_
|
||||
|
||||
Returns:
|
||||
_description_
|
||||
"""
|
||||
log_error("Not Implemented")
|
||||
# First find the largest number of voxel grids across environments:
|
||||
c_len = []
|
||||
pose_batch = []
|
||||
dims_batch = []
|
||||
names_batch = []
|
||||
vsize_batch = []
|
||||
for i in world_config_list:
|
||||
c = i.voxel  # voxel grids in this environment (the loop below reads voxel_size, which cuboids do not have)
|
||||
if c is not None:
|
||||
c_len.append(len(c))
|
||||
pose_batch.extend([i.pose for i in c])
|
||||
dims_batch.extend([i.dims for i in c])
|
||||
names_batch.extend([i.name for i in c])
|
||||
vsize_batch.extend([i.voxel_size for i in c])
|
||||
else:
|
||||
c_len.append(0)
|
||||
|
||||
max_obs = max(c_len)
|
||||
if max_obs < 1:
|
||||
log_warn("No obbs found")
|
||||
return
|
||||
# check if number of environments is same as config:
|
||||
reset_buffers = False
|
||||
if self._env_n_voxels is not None and len(world_config_list) != len(self._env_n_voxels):
|
||||
log_warn(
|
||||
"env_n_voxels is not same as world_config_list, reloading collision buffers (breaks CG)"
|
||||
)
|
||||
reset_buffers = True
|
||||
self.n_envs = len(world_config_list)
|
||||
self._env_n_voxels = torch.zeros(
|
||||
(self.n_envs), device=self.tensor_args.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
if self._voxel_tensor_list is not None and self._voxel_tensor_list[0].shape[1] < max_obs:
|
||||
log_warn(
|
||||
"number of obbs is greater than buffer, reloading collision buffers (breaks CG)"
|
||||
)
|
||||
reset_buffers = True
|
||||
# create cache if does not exist:
|
||||
if self._voxel_tensor_list is None or reset_buffers:
|
||||
log_info("Creating Obb cache" + str(max_obs))
|
||||
self._create_obb_cache(max_obs)
|
||||
|
||||
# load obstacles:
|
||||
## load data into gpu:
|
||||
voxel_batch = self._batch_tensor_voxel(pose_batch, dims_batch, vsize_batch)
|
||||
c_start = 0
|
||||
for i in range(len(self._env_n_voxels)):
|
||||
if c_len[i] > 0:
|
||||
# load obb:
|
||||
self._voxel_tensor_list[0][i, : c_len[i], :] = voxel_batch[0][
|
||||
c_start : c_start + c_len[i]
|
||||
]
|
||||
self._voxel_tensor_list[1][i, : c_len[i], :7] = voxel_batch[1][
|
||||
c_start : c_start + c_len[i]
|
||||
]
|
||||
self._voxel_tensor_list[2][i, : c_len[i]] = 1
|
||||
self._env_voxel_names[i][: c_len[i]] = names_batch[c_start : c_start + c_len[i]]
|
||||
self._voxel_tensor_list[2][i, c_len[i] :] = 0
|
||||
c_start += c_len[i]
|
||||
self._env_n_voxels[:] = torch.as_tensor(
|
||||
c_len, dtype=torch.int32, device=self.tensor_args.device
|
||||
)
|
||||
self.collision_types["voxel"] = True
|
||||
|
||||
return super().load_batch_collision_model(world_config_list)
|
||||
|
||||
def enable_obstacle(
|
||||
self,
|
||||
name: str,
|
||||
enable: bool = True,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
if self._env_voxel_names is not None and name in self._env_voxel_names[env_idx]:
|
||||
self.enable_voxel(enable, name, None, env_idx)
|
||||
else:
|
||||
return super().enable_obstacle(name, enable, env_idx)
|
||||
|
||||
def enable_voxel(
|
||||
self,
|
||||
enable: bool = True,
|
||||
name: Optional[str] = None,
|
||||
env_obj_idx: Optional[torch.Tensor] = None,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
"""Update obstacle dimensions
|
||||
|
||||
Args:
|
||||
obj_dims (torch.Tensor): [dim.x,dim.y, dim.z], give as [b,3]
|
||||
obj_idx (torch.Tensor or int):
|
||||
|
||||
"""
|
||||
if env_obj_idx is not None:
|
||||
self._voxel_tensor_list[2][env_obj_idx] = int(enable) # enable == 1
|
||||
else:
|
||||
# find index of given name:
|
||||
obs_idx = self.get_voxel_idx(name, env_idx)
|
||||
|
||||
self._voxel_tensor_list[2][env_idx, obs_idx] = int(enable)
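A one-line usage sketch (hypothetical grid name): temporarily disabling a loaded voxel grid without removing it from the cache:

voxel_world.enable_voxel(enable=False, name="scene_esdf", env_idx=0)  # voxel_world: a WorldVoxelCollision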
|
||||
|
||||
def update_obstacle_pose(
|
||||
self,
|
||||
name: str,
|
||||
w_obj_pose: Pose,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
if self._env_voxel_names is not None and name in self._env_voxel_names[env_idx]:
|
||||
self.update_voxel_pose(name=name, w_obj_pose=w_obj_pose, env_idx=env_idx)
|
||||
else:
|
||||
log_error("obstacle not found in OBB world model: " + name)
|
||||
|
||||
def update_voxel_data(self, new_voxel: VoxelGrid, env_idx: int = 0):
|
||||
obs_idx = self.get_voxel_idx(new_voxel.name, env_idx)
|
||||
self._voxel_tensor_list[3][env_idx, obs_idx, :, :] = new_voxel.feature_tensor.view(
|
||||
new_voxel.feature_tensor.shape[0], -1
|
||||
).to(dtype=self._voxel_tensor_list[3].dtype)
|
||||
self._voxel_tensor_list[0][env_idx, obs_idx, :3] = self.tensor_args.to_device(
|
||||
new_voxel.dims
|
||||
)
|
||||
self._voxel_tensor_list[0][env_idx, obs_idx, 3] = new_voxel.voxel_size
|
||||
self._voxel_tensor_list[1][env_idx, obs_idx, :7] = (
|
||||
Pose.from_list(new_voxel.pose, self.tensor_args).inverse().get_pose_vector()
|
||||
)
|
||||
self._voxel_tensor_list[2][env_idx, obs_idx] = int(True)
|
||||
|
||||
def update_voxel_features(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
env_obj_idx: Optional[torch.Tensor] = None,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
"""Update pose of a specific objects.
|
||||
This also updates the signed distance grid to account for the updated object pose.
|
||||
Args:
|
||||
obj_w_pose: Pose
|
||||
obj_idx:
|
||||
"""
|
||||
if env_obj_idx is not None:
|
||||
self._voxel_tensor_list[3][env_obj_idx, :] = features.to(
|
||||
dtype=self._voxel_tensor_list[3].dtype
|
||||
)
|
||||
else:
|
||||
obs_idx = self.get_voxel_idx(name, env_idx)
|
||||
self._voxel_tensor_list[3][env_idx, obs_idx, :] = features.to(
|
||||
dtype=self._voxel_tensor_list[3].dtype
|
||||
)
|
||||
|
||||
def update_voxel_pose(
|
||||
self,
|
||||
w_obj_pose: Optional[Pose] = None,
|
||||
obj_w_pose: Optional[Pose] = None,
|
||||
name: Optional[str] = None,
|
||||
env_obj_idx: Optional[torch.Tensor] = None,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
"""Update pose of a specific objects.
|
||||
This also updates the signed distance grid to account for the updated object pose.
|
||||
Args:
|
||||
obj_w_pose: Pose
|
||||
obj_idx:
|
||||
"""
|
||||
obj_w_pose = self._get_obstacle_poses(w_obj_pose, obj_w_pose)
|
||||
if env_obj_idx is not None:
|
||||
self._voxel_tensor_list[1][env_obj_idx, :7] = obj_w_pose.get_pose_vector()
|
||||
else:
|
||||
obs_idx = self.get_voxel_idx(name, env_idx)
|
||||
self._voxel_tensor_list[1][env_idx, obs_idx, :7] = obj_w_pose.get_pose_vector()
|
||||
|
||||
def get_voxel_idx(
|
||||
self,
|
||||
name: str,
|
||||
env_idx: int = 0,
|
||||
) -> int:
|
||||
if name not in self._env_voxel_names[env_idx]:
|
||||
log_error("Obstacle with name: " + name + " not found in current world", exc_info=True)
|
||||
return self._env_voxel_names[env_idx].index(name)
|
||||
|
||||
def get_voxel_grid(
|
||||
self,
|
||||
name: str,
|
||||
env_idx: int = 0,
|
||||
):
|
||||
obs_idx = self.get_voxel_idx(name, env_idx)
|
||||
voxel_params = np.round(
|
||||
self._voxel_tensor_list[0][env_idx, obs_idx, :].cpu().numpy().astype(np.float64), 6
|
||||
).tolist()
|
||||
voxel_pose = Pose(
|
||||
position=self._voxel_tensor_list[1][env_idx, obs_idx, :3],
|
||||
quaternion=self._voxel_tensor_list[1][env_idx, obs_idx, 3:7],
|
||||
)
|
||||
voxel_features = self._voxel_tensor_list[3][env_idx, obs_idx, :]
|
||||
voxel_grid = VoxelGrid(
|
||||
name=name,
|
||||
dims=voxel_params[:3],
|
||||
pose=voxel_pose.to_list(),
|
||||
voxel_size=voxel_params[3],
|
||||
feature_tensor=voxel_features,
|
||||
)
|
||||
return voxel_grid
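A usage sketch of the update/read-back pair above. voxel_world stands for an initialized WorldVoxelCollision, and "scene_esdf" is assumed to be a voxel grid name already present in the loaded world, since get_voxel_idx requires the name to exist in the cache:

import torch
from curobo.geom.types import VoxelGrid

new_grid = VoxelGrid(
    name="scene_esdf",
    dims=[1.0, 1.0, 1.0],
    pose=[0.0, 0.0, 0.0, 1, 0, 0, 0],
    voxel_size=0.02,
    feature_tensor=torch.zeros(50 * 50 * 50),  # one signed distance per voxel
)
voxel_world.update_voxel_data(new_grid, env_idx=0)                 # writes features, dims, and pose into the cache
stored_grid = voxel_world.get_voxel_grid("scene_esdf", env_idx=0)  # reads the cached entry back as a VoxelGrid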
|
||||
|
||||
def get_sphere_distance(
|
||||
self,
|
||||
query_sphere,
|
||||
collision_query_buffer: CollisionQueryBuffer,
|
||||
weight: torch.Tensor,
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
sum_collisions: bool = True,
|
||||
compute_esdf: bool = False,
|
||||
):
|
||||
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
|
||||
return super().get_sphere_distance(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
|
||||
b, h, n, _ = query_sphere.shape # This can be read from collision query buffer
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
dist = SdfSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
query_sphere.requires_grad,
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
sum_collisions,
|
||||
compute_esdf,
|
||||
)
|
||||
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_sphere_distance(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
compute_esdf=compute_esdf,
|
||||
)
|
||||
if compute_esdf:
|
||||
d_val = torch.maximum(dist.view(d_prim.shape), d_prim)
|
||||
else:
|
||||
d_val = dist.view(d_prim.shape) + d_prim
|
||||
|
||||
return d_val
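A toy illustration of how the voxel result is merged with the primitive/mesh result above (all numbers invented): with compute_esdf the per-query values are combined with an elementwise maximum, otherwise the two cost layers are summed.

import torch

d_voxel = torch.tensor([0.12, 0.00])   # distance cost from the voxel layer, per query sphere
d_prim = torch.tensor([0.05, 0.20])    # distance cost from primitives and meshes
esdf = torch.maximum(d_voxel, d_prim)  # compute_esdf=True -> [0.12, 0.20]
cost = d_voxel + d_prim                # compute_esdf=False -> [0.17, 0.20]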
|
||||
|
||||
def get_sphere_collision(
|
||||
self,
|
||||
query_sphere,
|
||||
collision_query_buffer: CollisionQueryBuffer,
|
||||
weight: torch.Tensor,
|
||||
activation_distance: torch.Tensor,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
**kwargs,
|
||||
):
|
||||
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
|
||||
return super().get_sphere_collision(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
)
|
||||
|
||||
if return_loss:
|
||||
raise ValueError("cannot return loss for classification, use get_sphere_distance")
|
||||
b, h, n, _ = query_sphere.shape
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
dist = SdfSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
query_sphere.requires_grad,
|
||||
False,
|
||||
use_batch_env,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_sphere_collision(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight,
|
||||
activation_distance=activation_distance,
|
||||
env_query_idx=env_query_idx,
|
||||
return_loss=return_loss,
|
||||
)
|
||||
d_val = dist.view(d_prim.shape) + d_prim
|
||||
return d_val
|
||||
|
||||
def get_swept_sphere_distance(
|
||||
self,
|
||||
query_sphere,
|
||||
collision_query_buffer: CollisionQueryBuffer,
|
||||
weight: torch.Tensor,
|
||||
activation_distance: torch.Tensor,
|
||||
speed_dt: torch.Tensor,
|
||||
sweep_steps: int,
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
sum_collisions: bool = True,
|
||||
):
|
||||
"""
|
||||
Computes the signed distance via analytic function
|
||||
Args:
|
||||
tensor_sphere: b, n, 4
|
||||
"""
|
||||
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
|
||||
return super().get_swept_sphere_distance(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
env_query_idx=env_query_idx,
|
||||
sweep_steps=sweep_steps,
|
||||
activation_distance=activation_distance,
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
b, h, n, _ = query_sphere.shape
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
|
||||
dist = SdfSweptSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
speed_dt,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
query_sphere.requires_grad,
|
||||
True,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
sum_collisions,
|
||||
)
|
||||
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_swept_sphere_distance(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
env_query_idx=env_query_idx,
|
||||
sweep_steps=sweep_steps,
|
||||
activation_distance=activation_distance,
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
sum_collisions=sum_collisions,
|
||||
)
|
||||
|
||||
d_val = dist.view(d_prim.shape) + d_prim
|
||||
return d_val
|
||||
|
||||
def get_swept_sphere_collision(
|
||||
self,
|
||||
query_sphere,
|
||||
collision_query_buffer: CollisionQueryBuffer,
|
||||
weight: torch.Tensor,
|
||||
activation_distance: torch.Tensor,
|
||||
speed_dt: torch.Tensor,
|
||||
sweep_steps: int,
|
||||
enable_speed_metric=False,
|
||||
env_query_idx: Optional[torch.Tensor] = None,
|
||||
return_loss=False,
|
||||
):
|
||||
"""
|
||||
Computes the signed distance via analytic function
|
||||
Args:
|
||||
tensor_sphere: b, n, 4
|
||||
"""
|
||||
if "voxel" not in self.collision_types or not self.collision_types["voxel"]:
|
||||
return super().get_swept_sphere_collision(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
env_query_idx=env_query_idx,
|
||||
sweep_steps=sweep_steps,
|
||||
activation_distance=activation_distance,
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
)
|
||||
if return_loss:
|
||||
raise ValueError("cannot return loss for classify, use get_swept_sphere_distance")
|
||||
b, h, n, _ = query_sphere.shape
|
||||
|
||||
use_batch_env = True
|
||||
if env_query_idx is None:
|
||||
use_batch_env = False
|
||||
env_query_idx = self._env_n_voxels
|
||||
dist = SdfSweptSphereVoxel.apply(
|
||||
query_sphere,
|
||||
collision_query_buffer.voxel_collision_buffer.distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.grad_distance_buffer,
|
||||
collision_query_buffer.voxel_collision_buffer.sparsity_index_buffer,
|
||||
weight,
|
||||
activation_distance,
|
||||
self.max_distance,
|
||||
speed_dt,
|
||||
self._voxel_tensor_list[3],
|
||||
self._voxel_tensor_list[0],
|
||||
self._voxel_tensor_list[1],
|
||||
self._voxel_tensor_list[2],
|
||||
self._env_n_voxels,
|
||||
env_query_idx,
|
||||
self._voxel_tensor_list[0].shape[1],
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
sweep_steps,
|
||||
enable_speed_metric,
|
||||
query_sphere.requires_grad,
|
||||
False,
|
||||
use_batch_env,
|
||||
return_loss,
|
||||
True,
|
||||
)
|
||||
if (
|
||||
"primitive" not in self.collision_types
|
||||
or not self.collision_types["primitive"]
|
||||
or "mesh" not in self.collision_types
|
||||
or not self.collision_types["mesh"]
|
||||
):
|
||||
return dist
|
||||
d_prim = super().get_swept_sphere_collision(
|
||||
query_sphere,
|
||||
collision_query_buffer,
|
||||
weight=weight,
|
||||
env_query_idx=env_query_idx,
|
||||
sweep_steps=sweep_steps,
|
||||
activation_distance=activation_distance,
|
||||
speed_dt=speed_dt,
|
||||
enable_speed_metric=enable_speed_metric,
|
||||
return_loss=return_loss,
|
||||
)
|
||||
d_val = dist.view(d_prim.shape) + d_prim
|
||||
return d_val
|
||||
|
||||
def clear_cache(self):
|
||||
if self._voxel_tensor_list is not None:
|
||||
self._voxel_tensor_list[2][:] = 0
|
||||
self._voxel_tensor_list[-1][:] = -1.0 * self.max_distance
|
||||
self._env_n_voxels[:] = 0
|
||||
@@ -18,6 +18,7 @@ import warp as wp
|
||||
# CuRobo
|
||||
from curobo.curobolib.kinematics import rotation_matrix_to_quaternion
|
||||
from curobo.util.logger import log_error
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.warp import init_warp
|
||||
|
||||
|
||||
@@ -27,11 +28,11 @@ def transform_points(
|
||||
if out_points is None:
|
||||
out_points = torch.zeros((points.shape[0], 3), device=points.device, dtype=points.dtype)
|
||||
if out_gp is None:
|
||||
out_gp = torch.zeros((position.shape[0], 3), device=position.device)
|
||||
out_gp = torch.zeros((position.shape[0], 3), device=position.device, dtype=points.dtype)
|
||||
if out_gq is None:
|
||||
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device)
|
||||
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device, dtype=points.dtype)
|
||||
if out_gpt is None:
|
||||
out_gpt = torch.zeros((points.shape[0], 3), device=position.device)
|
||||
out_gpt = torch.zeros((points.shape[0], 3), device=position.device, dtype=points.dtype)
|
||||
out_points = TransformPoint.apply(
|
||||
position, quaternion, points, out_points, out_gp, out_gq, out_gpt
|
||||
)
|
||||
@@ -46,18 +47,20 @@ def batch_transform_points(
|
||||
(points.shape[0], points.shape[1], 3), device=points.device, dtype=points.dtype
|
||||
)
|
||||
if out_gp is None:
|
||||
out_gp = torch.zeros((position.shape[0], 3), device=position.device)
|
||||
out_gp = torch.zeros((position.shape[0], 3), device=position.device, dtype=points.dtype)
|
||||
if out_gq is None:
|
||||
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device)
|
||||
out_gq = torch.zeros((quaternion.shape[0], 4), device=quaternion.device, dtype=points.dtype)
|
||||
if out_gpt is None:
|
||||
out_gpt = torch.zeros((points.shape[0], points.shape[1], 3), device=position.device)
|
||||
out_gpt = torch.zeros(
|
||||
(points.shape[0], points.shape[1], 3), device=position.device, dtype=points.dtype
|
||||
)
|
||||
out_points = BatchTransformPoint.apply(
|
||||
position, quaternion, points, out_points, out_gp, out_gq, out_gpt
|
||||
)
|
||||
return out_points
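The dtype changes above matter when queries run in reduced precision; a small sketch of the invariant they enforce (shapes are illustrative):

import torch

points = torch.ones(4, 10, 3, dtype=torch.float16)
# Gradient scratch buffers now inherit the dtype of the points instead of defaulting
# to float32, so the backward pass no longer mixes dtypes.
out_gpt = torch.zeros((points.shape[0], points.shape[1], 3), dtype=points.dtype)
assert out_gpt.dtype == points.dtype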
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def get_inv_transform(w_rot_c, w_trans_c):
|
||||
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
|
||||
c_rot_w = w_rot_c.transpose(-1, -2)
|
||||
@@ -65,7 +68,7 @@ def get_inv_transform(w_rot_c, w_trans_c):
|
||||
return c_rot_w, c_trans_w
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def transform_point_inverse(point, rot, trans):
|
||||
# type: (Tensor, Tensor, Tensor) -> Tensor
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Standard Library
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
@@ -564,6 +565,68 @@ class PointCloud(Obstacle):
|
||||
return new_spheres
|
||||
|
||||
|
||||
@dataclass
class VoxelGrid(Obstacle):
    dims: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0])
    voxel_size: float = 0.02 # meters
    feature_tensor: Optional[torch.Tensor] = None
    xyzr_tensor: Optional[torch.Tensor] = None
    feature_dtype: torch.dtype = torch.float32

    def __post_init__(self):
        if self.feature_tensor is not None:
            self.feature_dtype = self.feature_tensor.dtype

    def create_xyzr_tensor(
        self, transform_to_origin: bool = False, tensor_args: TensorDeviceType = TensorDeviceType()
    ):
        bounds = self.dims
        low = [-bounds[0] / 2, -bounds[1] / 2, -bounds[2] / 2]
        high = [bounds[0] / 2, bounds[1] / 2, bounds[2] / 2]
        trange = [h - l for l, h in zip(low, high)]
        x = torch.linspace(
            low[0], high[0], int(math.floor(trange[0] / self.voxel_size)), device=tensor_args.device
        )
        y = torch.linspace(
            low[1], high[1], int(math.floor(trange[1] / self.voxel_size)), device=tensor_args.device
        )
        z = torch.linspace(
            low[2], high[2], int(math.floor(trange[2] / self.voxel_size)), device=tensor_args.device
        )
        w, l, h = x.shape[0], y.shape[0], z.shape[0]
        xyz = (
            torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
        )

        if transform_to_origin:
            pose = Pose.from_list(self.pose, tensor_args=tensor_args)
            xyz = pose.transform_points(xyz.contiguous())
        r = torch.zeros_like(xyz[:, 0:1]) + (self.voxel_size * 0.5)
        xyzr = torch.cat([xyz, r], dim=1)
        return xyzr

    def get_occupied_voxels(self, feature_threshold: Optional[float] = None):
        if feature_threshold is None:
            feature_threshold = -1.0 * self.voxel_size
        if self.xyzr_tensor is None or self.feature_tensor is None:
            log_error("Feature tensor or xyzr tensor is empty")
        xyzr = self.xyzr_tensor.clone()
        xyzr[:, 3] = self.feature_tensor
        occupied = xyzr[self.feature_tensor > feature_threshold]
        return occupied

    def clone(self):
        return VoxelGrid(
            name=self.name,
            pose=self.pose.copy(),
            dims=self.dims.copy(),
            feature_tensor=self.feature_tensor.clone() if self.feature_tensor is not None else None,
            xyzr_tensor=self.xyzr_tensor.clone() if self.xyzr_tensor is not None else None,
            feature_dtype=self.feature_dtype,
            voxel_size=self.voxel_size,
        )
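A minimal construction sketch for the new dataclass. The name, pose, and feature values are illustrative, the pose list is assumed to follow the x, y, z, qw, qx, qy, qz convention, and create_xyzr_tensor allocates on the default CUDA device.

import torch

grid = VoxelGrid(
    name="scene_voxels",
    pose=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    dims=[1.0, 1.0, 1.0],
    voxel_size=0.05,
)
xyzr = grid.create_xyzr_tensor(transform_to_origin=True)
grid.xyzr_tensor = xyzr
# Mark every voxel as free (large negative signed distance), then query occupancy.
grid.feature_tensor = torch.zeros(xyzr.shape[0], device=xyzr.device) - 10.0
occupied = grid.get_occupied_voxels()  # empty: no feature exceeds the default threshold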
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorldConfig(Sequence):
|
||||
"""Representation of World for use in CuRobo."""
|
||||
@@ -586,25 +649,13 @@ class WorldConfig(Sequence):
|
||||
#: BloxMap obstacle.
|
||||
blox: Optional[List[BloxMap]] = None
|
||||
|
||||
voxel: Optional[List[VoxelGrid]] = None
|
||||
|
||||
#: List of all obstacles in world.
|
||||
objects: Optional[List[Obstacle]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# create objects list:
|
||||
if self.objects is None:
|
||||
self.objects = []
|
||||
if self.sphere is not None:
|
||||
self.objects += self.sphere
|
||||
if self.cuboid is not None:
|
||||
self.objects += self.cuboid
|
||||
if self.capsule is not None:
|
||||
self.objects += self.capsule
|
||||
if self.mesh is not None:
|
||||
self.objects += self.mesh
|
||||
if self.blox is not None:
|
||||
self.objects += self.blox
|
||||
if self.cylinder is not None:
|
||||
self.objects += self.cylinder
|
||||
if self.sphere is None:
|
||||
self.sphere = []
|
||||
if self.cuboid is None:
|
||||
@@ -617,6 +668,18 @@ class WorldConfig(Sequence):
|
||||
self.cylinder = []
|
||||
if self.blox is None:
|
||||
self.blox = []
|
||||
if self.voxel is None:
|
||||
self.voxel = []
|
||||
if self.objects is None:
|
||||
self.objects = (
|
||||
self.sphere
|
||||
+ self.cuboid
|
||||
+ self.capsule
|
||||
+ self.mesh
|
||||
+ self.cylinder
|
||||
+ self.blox
|
||||
+ self.voxel
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.objects)
|
||||
@@ -632,6 +695,7 @@ class WorldConfig(Sequence):
|
||||
capsule=self.capsule.copy() if self.capsule is not None else None,
|
||||
cylinder=self.cylinder.copy() if self.cylinder is not None else None,
|
||||
blox=self.blox.copy() if self.blox is not None else None,
|
||||
voxel=self.voxel.copy() if self.voxel is not None else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -642,6 +706,7 @@ class WorldConfig(Sequence):
|
||||
mesh = None
|
||||
blox = None
|
||||
cylinder = None
|
||||
voxel = None
|
||||
# load yaml:
|
||||
if "cuboid" in data_dict.keys():
|
||||
cuboid = [Cuboid(name=x, **data_dict["cuboid"][x]) for x in data_dict["cuboid"]]
|
||||
@@ -655,6 +720,8 @@ class WorldConfig(Sequence):
|
||||
cylinder = [Cylinder(name=x, **data_dict["cylinder"][x]) for x in data_dict["cylinder"]]
|
||||
if "blox" in data_dict.keys():
|
||||
blox = [BloxMap(name=x, **data_dict["blox"][x]) for x in data_dict["blox"]]
|
||||
if "voxel" in data_dict.keys():
|
||||
voxel = [VoxelGrid(name=x, **data_dict["voxel"][x]) for x in data_dict["voxel"]]
|
||||
|
||||
return WorldConfig(
|
||||
cuboid=cuboid,
|
||||
@@ -663,6 +730,7 @@ class WorldConfig(Sequence):
|
||||
cylinder=cylinder,
|
||||
mesh=mesh,
|
||||
blox=blox,
|
||||
voxel=voxel,
|
||||
)
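Assuming the surrounding static loader is the usual WorldConfig.from_dict entry point, a world description can now carry a voxel section next to cuboid, mesh, and blox entries; a hedged sketch:

world_dict = {
    "voxel": {
        "scene_esdf": {
            "dims": [2.0, 2.0, 2.0],
            "voxel_size": 0.05,
            "pose": [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        }
    },
}
world_cfg = WorldConfig.from_dict(world_dict)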
|
||||
|
||||
# load world config as obbs: convert all types to obbs
|
||||
@@ -688,6 +756,10 @@ class WorldConfig(Sequence):
|
||||
|
||||
if current_world.mesh is not None and len(current_world.mesh) > 0:
|
||||
mesh_obb = [x.get_cuboid() for x in current_world.mesh]
|
||||
|
||||
if current_world.voxel is not None and len(current_world.voxel) > 0:
|
||||
log_error("VoxelGrid cannot be converted to obb world")
|
||||
|
||||
return WorldConfig(
|
||||
cuboid=cuboid_obb + sphere_obb + capsule_obb + cylinder_obb + mesh_obb + blox_obb
|
||||
)
|
||||
@@ -714,6 +786,8 @@ class WorldConfig(Sequence):
|
||||
for i in range(len(current_world.blox)):
|
||||
if current_world.blox[i].mesh is not None:
|
||||
blox_obb.append(current_world.blox[i].get_mesh(process=process))
|
||||
if current_world.voxel is not None and len(current_world.voxel) > 0:
|
||||
log_error("VoxelGrid cannot be converted to mesh world")
|
||||
|
||||
return WorldConfig(
|
||||
mesh=current_world.mesh
|
||||
@@ -750,6 +824,7 @@ class WorldConfig(Sequence):
|
||||
return WorldConfig(
|
||||
mesh=current_world.mesh + sphere_obb + capsule_obb + cylinder_obb + blox_obb,
|
||||
cuboid=cuboid_obb,
|
||||
voxel=current_world.voxel,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -822,6 +897,8 @@ class WorldConfig(Sequence):
|
||||
self.cylinder.append(obstacle)
|
||||
elif isinstance(obstacle, Capsule):
|
||||
self.capsule.append(obstacle)
|
||||
elif isinstance(obstacle, VoxelGrid):
|
||||
self.voxel.append(obstacle)
|
||||
else:
|
||||
ValueError("Obstacle type not supported")
|
||||
self.objects.append(obstacle)
|
||||
|
||||
@@ -33,6 +33,7 @@ from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import JointState, RobotConfig, State
|
||||
from curobo.util.logger import log_info, log_warn
|
||||
from curobo.util.sample_lib import HaltonGenerator
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.trajectory import InterpolateType, get_interpolated_trajectory
|
||||
from curobo.util_file import (
|
||||
get_robot_configs_path,
|
||||
@@ -1029,7 +1030,7 @@ class GraphPlanBase(GraphConfig):
|
||||
pass
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(dynamic=True)
|
||||
def get_unique_nodes(dist_node: torch.Tensor, nodes: torch.Tensor, node_distance: float):
|
||||
node_flag = dist_node <= node_distance
|
||||
dist_node[node_flag] = 0.0
|
||||
@@ -1042,7 +1043,7 @@ def get_unique_nodes(dist_node: torch.Tensor, nodes: torch.Tensor, node_distance
|
||||
return unique_nodes, n_inv
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(force_jit=True, dynamic=True)
|
||||
def add_new_nodes_jit(
|
||||
nodes, new_nodes, flag, cat_buffer, path, idx, i: int, dof: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
@@ -1066,7 +1067,7 @@ def add_new_nodes_jit(
|
||||
return path, node_set, new_nodes.shape[0]
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(force_jit=True, dynamic=True)
|
||||
def add_all_nodes_jit(
|
||||
nodes, cat_buffer, path, i: int, dof: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
@@ -1084,20 +1085,20 @@ def add_all_nodes_jit(
|
||||
return path, node_set, nodes.shape[0]
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(force_jit=True, dynamic=True)
|
||||
def compute_distance_norm_jit(pt, batch_pts, distance_weight):
|
||||
vec = (batch_pts - pt) * distance_weight
|
||||
dist = torch.norm(vec, dim=-1)
|
||||
return dist
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(dynamic=True)
|
||||
def compute_distance_jit(pt, batch_pts, distance_weight):
|
||||
vec = (batch_pts - pt) * distance_weight
|
||||
return vec
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(dynamic=True)
|
||||
def compute_rotation_frame_jit(
|
||||
x_start: torch.Tensor, x_goal: torch.Tensor, rot_frame_col: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
@@ -1114,7 +1115,7 @@ def compute_rotation_frame_jit(
|
||||
return C
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(force_jit=True, dynamic=True)
|
||||
def biased_vertex_projection_jit(
|
||||
x_start,
|
||||
x_goal,
|
||||
@@ -1144,7 +1145,7 @@ def biased_vertex_projection_jit(
|
||||
return x_samples
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(force_jit=True, dynamic=True)
|
||||
def cat_xc_jit(x, n: int):
|
||||
c = x[:, 0:1] * 0.0
|
||||
xc_search = torch.cat((x, c), dim=1)[:n, :]
|
||||
|
||||
@@ -20,12 +20,13 @@ import torch.autograd.profiler as profiler
|
||||
# CuRobo
|
||||
from curobo.curobolib.opt import LBFGScu
|
||||
from curobo.opt.newton.newton_base import NewtonOptBase, NewtonOptConfig
|
||||
from curobo.util.logger import log_warn
|
||||
from curobo.util.logger import log_info
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
# kernel for l-bfgs:
|
||||
# @torch.jit.script
|
||||
def compute_step_direction(
|
||||
@get_torch_jit_decorator()
|
||||
def jit_lbfgs_compute_step_direction(
|
||||
alpha_buffer,
|
||||
rho_buffer,
|
||||
y_buffer,
|
||||
@@ -35,6 +36,8 @@ def compute_step_direction(
|
||||
epsilon: float,
|
||||
stable_mode: bool = True,
|
||||
):
|
||||
grad_q = grad_q.transpose(-1, -2)
|
||||
|
||||
# m = 15 (int)
|
||||
# y_buffer, s_buffer: m x b x 175
|
||||
# q, grad_q: b x 175
|
||||
@@ -60,12 +63,39 @@ def compute_step_direction(
|
||||
return -1.0 * r
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def jit_lbfgs_update_buffers(
|
||||
q, grad_q, s_buffer, y_buffer, rho_buffer, x_0, grad_0, stable_mode: bool
|
||||
):
|
||||
grad_q = grad_q.transpose(-1, -2)
|
||||
q = q.unsqueeze(-1)
|
||||
|
||||
y = grad_q - grad_0
|
||||
s = q - x_0
|
||||
rho = 1 / (y.transpose(-1, -2) @ s)
|
||||
if stable_mode:
|
||||
rho = torch.nan_to_num(rho, 0.0, 0.0, 0.0)
|
||||
s_buffer[0] = s
|
||||
s_buffer[:] = torch.roll(s_buffer, -1, dims=0)
|
||||
y_buffer[0] = y
|
||||
y_buffer[:] = torch.roll(y_buffer, -1, dims=0) # .copy_(y_buff)
|
||||
|
||||
rho_buffer[0] = rho
|
||||
|
||||
rho_buffer[:] = torch.roll(rho_buffer, -1, dims=0)
|
||||
|
||||
x_0.copy_(q)
|
||||
grad_0.copy_(grad_q)
|
||||
return s_buffer, y_buffer, rho_buffer, x_0, grad_0
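A plain-torch sketch of the rolling curvature-pair update that jit_lbfgs_update_buffers performs, on fabricated shapes (history m=3, one problem, four optimization variables):

import torch

m, d = 3, 4
s_buffer = torch.zeros(m, 1, d, 1)
y_buffer = torch.zeros(m, 1, d, 1)
rho_buffer = torch.zeros(m, 1, 1, 1)
x_0, grad_0 = torch.zeros(1, d, 1), torch.zeros(1, d, 1)
q, grad_q = torch.rand(1, d, 1), torch.rand(1, d, 1)

y, s = grad_q - grad_0, q - x_0
rho = 1.0 / (y.transpose(-1, -2) @ s)  # curvature term; nan_to_num'd when stable_mode is set
# The newest pair is written into slot 0 and then rolled to the back of the history.
s_buffer[0] = s
s_buffer = torch.roll(s_buffer, -1, dims=0)
y_buffer[0] = y
y_buffer = torch.roll(y_buffer, -1, dims=0)
rho_buffer[0] = rho
rho_buffer = torch.roll(rho_buffer, -1, dims=0)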
|
||||
|
||||
|
||||
@dataclass
|
||||
class LBFGSOptConfig(NewtonOptConfig):
|
||||
history: int = 10
|
||||
epsilon: float = 0.01
|
||||
use_cuda_kernel: bool = True
|
||||
stable_mode: bool = True
|
||||
use_shared_buffers_kernel: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
return super().__post_init__()
|
||||
@@ -77,11 +107,15 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
|
||||
if config is not None:
|
||||
LBFGSOptConfig.__init__(self, **vars(config))
|
||||
NewtonOptBase.__init__(self)
|
||||
if self.d_opt >= 1024 or self.history > 15:
|
||||
log_warn("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>15")
|
||||
if (
|
||||
self.d_opt >= 1024
|
||||
or self.history > 31
|
||||
or ((self.d_opt * self.history + 33) * 4 >= 48000)
|
||||
):
|
||||
log_info("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>15")
|
||||
self.use_cuda_kernel = False
|
||||
if self.history >= self.d_opt:
|
||||
log_warn("LBFGS: history >= d_opt, reducing history to d_opt-1")
|
||||
if self.history > self.d_opt:
|
||||
log_info("LBFGS: history >= d_opt, reducing history to d_opt-1")
|
||||
self.history = self.d_opt - 1
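A worked check of the bound used above, with illustrative numbers: a problem with d_opt = 175 variables and history = 15 needs (175 * 15 + 33) * 4 = 10,632 bytes of shared memory, well under the 48 KB limit, so the fused CUDA kernel stays enabled.

d_opt, history = 175, 15
shared_bytes = (d_opt * history + 33) * 4
assert d_opt < 1024 and history <= 31 and shared_bytes < 48000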
|
||||
|
||||
@profiler.record_function("lbfgs/reset")
|
||||
@@ -130,7 +164,7 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
|
||||
def _get_step_direction(self, cost, q, grad_q):
|
||||
if self.use_cuda_kernel:
|
||||
with profiler.record_function("lbfgs/fused"):
|
||||
dq = LBFGScu._call_cuda(
|
||||
dq = LBFGScu.apply(
|
||||
self.step_q_buffer,
|
||||
self.rho_buffer,
|
||||
self.y_buffer,
|
||||
@@ -141,13 +175,14 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
|
||||
self.grad_0,
|
||||
self.epsilon,
|
||||
self.stable_mode,
|
||||
self.use_shared_buffers_kernel,
|
||||
)
|
||||
|
||||
else:
|
||||
grad_q = grad_q.transpose(-1, -2)
|
||||
q = q.unsqueeze(-1)
|
||||
|
||||
self._update_buffers(q, grad_q)
|
||||
dq = compute_step_direction(
|
||||
|
||||
dq = jit_lbfgs_compute_step_direction(
|
||||
self.alpha_buffer,
|
||||
self.rho_buffer,
|
||||
self.y_buffer,
|
||||
@@ -177,6 +212,23 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
|
||||
return -1.0 * r
|
||||
|
||||
def _update_buffers(self, q, grad_q):
|
||||
if True:
|
||||
self.s_buffer, self.y_buffer, self.rho_buffer, self.x_0, self.grad_0 = (
|
||||
jit_lbfgs_update_buffers(
|
||||
q,
|
||||
grad_q,
|
||||
self.s_buffer,
|
||||
self.y_buffer,
|
||||
self.rho_buffer,
|
||||
self.x_0,
|
||||
self.grad_0,
|
||||
self.stable_mode,
|
||||
)
|
||||
)
|
||||
return
|
||||
grad_q = grad_q.transpose(-1, -2)
|
||||
q = q.unsqueeze(-1)
|
||||
|
||||
y = grad_q - self.grad_0
|
||||
s = q - self.x_0
|
||||
rho = 1 / (y.transpose(-1, -2) @ s)
|
||||
|
||||
@@ -26,6 +26,7 @@ from curobo.opt.opt_base import Optimizer, OptimizerConfig
|
||||
from curobo.rollout.dynamics_model.integration_utils import build_fd_matrix
|
||||
from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.tensor import T_BDOF, T_BHDOF_float, T_BHValue_float, T_BValue_float, T_HDOF_float
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
class LineSearchType(Enum):
|
||||
@@ -108,6 +109,7 @@ class NewtonOptBase(Optimizer, NewtonOptConfig):
|
||||
self.action_horizon, device=self.tensor_args.device, dtype=self.tensor_args.dtype
|
||||
).unsqueeze(0)
|
||||
self._temporal_mat += eye_mat
|
||||
self.rollout_fn.sum_horizon = True
|
||||
|
||||
def reset_cuda_graph(self):
|
||||
if self.cu_opt_graph is not None:
|
||||
@@ -222,11 +224,14 @@ class NewtonOptBase(Optimizer, NewtonOptConfig):
|
||||
self.n_problems * self.num_particles, self.action_horizon, self.rollout_fn.d_action
|
||||
)
|
||||
trajectories = self.rollout_fn(x_in) # x_n = (batch*line_search_scale) x horizon x d_action
|
||||
cost = torch.sum(
|
||||
trajectories.costs.view(self.n_problems, self.num_particles, self.horizon),
|
||||
dim=-1,
|
||||
keepdim=True,
|
||||
)
|
||||
if len(trajectories.costs.shape) == 2:
|
||||
cost = torch.sum(
|
||||
trajectories.costs.view(self.n_problems, self.num_particles, self.horizon),
|
||||
dim=-1,
|
||||
keepdim=True,
|
||||
)
|
||||
else:
|
||||
cost = trajectories.costs.view(self.n_problems, self.num_particles, 1)
|
||||
g_x = cost.backward(gradient=self.l_vec, retain_graph=False)
|
||||
g_x = x_n.grad.detach()
|
||||
return (
|
||||
@@ -542,7 +547,7 @@ class NewtonOptBase(Optimizer, NewtonOptConfig):
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def get_x_set_jit(step_vec, x, alpha_list, action_lows, action_highs):
|
||||
# step_direction = step_direction.detach()
|
||||
x_set = torch.clamp(x.unsqueeze(-2) + alpha_list * step_vec, action_lows, action_highs)
|
||||
@@ -550,7 +555,7 @@ def get_x_set_jit(step_vec, x, alpha_list, action_lows, action_highs):
|
||||
return x_set
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def _armijo_line_search_tail_jit(c, g_x, step_direction, c_1, alpha_list, c_idx, x_set, d_opt):
|
||||
c_0 = c[:, 0:1]
|
||||
g_0 = g_x[:, 0:1]
|
||||
@@ -581,7 +586,7 @@ def _armijo_line_search_tail_jit(c, g_x, step_direction, c_1, alpha_list, c_idx,
|
||||
return (best_x, best_c, best_grad)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def _wolfe_search_tail_jit(c, g_x, x_set, m, d_opt: int):
|
||||
b, h, _ = x_set.shape
|
||||
g_x = g_x.view(b * h, -1)
|
||||
@@ -593,7 +598,7 @@ def _wolfe_search_tail_jit(c, g_x, x_set, m, d_opt: int):
|
||||
return (best_x, best_c, best_grad)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def scale_action(dx, action_step_max):
|
||||
scale_value = torch.max(torch.abs(dx) / action_step_max, dim=-1, keepdim=True)[0]
|
||||
scale_value = torch.clamp(scale_value, 1.0)
|
||||
@@ -601,7 +606,7 @@ def scale_action(dx, action_step_max):
|
||||
return dx_scaled
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def check_convergence(
|
||||
best_iteration: torch.Tensor, current_iteration: torch.Tensor, last_best: int
|
||||
) -> bool:
|
||||
|
||||
@@ -16,7 +16,15 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.opt.particle.parallel_mppi import CovType, ParallelMPPI, ParallelMPPIConfig
|
||||
from curobo.opt.particle.parallel_mppi import (
|
||||
CovType,
|
||||
ParallelMPPI,
|
||||
ParallelMPPIConfig,
|
||||
Trajectory,
|
||||
jit_blend_mean,
|
||||
)
|
||||
from curobo.opt.particle.particle_opt_base import SampleMode
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,14 +46,72 @@ class ParallelES(ParallelMPPI, ParallelESConfig):
|
||||
)
|
||||
# get the new means from here
|
||||
# use Evolutionary Strategy Mean here:
|
||||
new_mean = jit_blend_mean(self.mean_action, new_mean, self.step_size_mean)
|
||||
return new_mean
|
||||
|
||||
def _exp_util_from_costs(self, costs):
|
||||
total_costs = self._compute_total_cost(costs)
|
||||
w = self._exp_util(total_costs)
|
||||
return w
|
||||
|
||||
def _exp_util(self, total_costs):
|
||||
w = calc_exp(total_costs)
|
||||
return w
|
||||
|
||||
def _compute_mean_covariance(self, costs, actions):
|
||||
w = self._exp_util_from_costs(costs)
|
||||
w = w.unsqueeze(-1).unsqueeze(-1)
|
||||
new_mean = self._compute_mean(w, actions)
|
||||
new_cov = self._compute_covariance(w, actions)
|
||||
self._update_cov_scale(new_cov)
|
||||
|
||||
@torch.jit.script
|
||||
return new_mean, new_cov
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_distribution(self, trajectories: Trajectory):
|
||||
costs = trajectories.costs
|
||||
actions = trajectories.actions
|
||||
|
||||
# Let's reshape to n_problems now:
|
||||
|
||||
# first find the means before doing exponential utility:
|
||||
|
||||
# Update best action
|
||||
if self.sample_mode == SampleMode.BEST:
|
||||
w = self._exp_util_from_costs(costs)
|
||||
|
||||
best_idx = torch.argmax(w, dim=-1)
|
||||
self.best_traj.copy_(actions[self.problem_col, best_idx])
|
||||
|
||||
if self.store_rollouts and self.visual_traj is not None:
|
||||
total_costs = self._compute_total_cost(costs)
|
||||
vis_seq = getattr(trajectories.state, self.visual_traj)
|
||||
top_values, top_idx = torch.topk(total_costs, 20, dim=1)
|
||||
self.top_values = top_values
|
||||
self.top_idx = top_idx
|
||||
top_trajs = torch.index_select(vis_seq, 0, top_idx[0])
|
||||
for i in range(1, top_idx.shape[0]):
|
||||
trajs = torch.index_select(
|
||||
vis_seq, 0, top_idx[i] + (self.particles_per_problem * i)
|
||||
)
|
||||
top_trajs = torch.cat((top_trajs, trajs), dim=0)
|
||||
if self.top_trajs is None or top_trajs.shape != self.top_trajs.shape:
|
||||
self.top_trajs = top_trajs
|
||||
else:
|
||||
self.top_trajs.copy_(top_trajs)
|
||||
|
||||
if not self.update_cov:
|
||||
w = w.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
new_mean = self._compute_mean(w, actions)
|
||||
else:
|
||||
new_mean, new_cov = self._compute_mean_covariance(costs, actions)
|
||||
self.cov_action.copy_(new_cov)
|
||||
|
||||
self.mean_action.copy_(new_mean)
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def calc_exp(
|
||||
total_costs,
|
||||
):
|
||||
@@ -58,7 +124,7 @@ def calc_exp(
|
||||
return w
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def compute_es_mean(
|
||||
w, actions, mean_action, full_inv_cov, num_particles: int, learning_rate: float
|
||||
):
|
||||
|
||||
@@ -33,6 +33,7 @@ from curobo.types.robot import State
|
||||
from curobo.util.logger import log_info
|
||||
from curobo.util.sample_lib import HaltonSampleLib, SampleConfig, SampleLib
|
||||
from curobo.util.tensor_util import copy_tensor
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
class BaseActionType(Enum):
|
||||
@@ -187,11 +188,43 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
|
||||
# w = torch.softmax((-1.0 / self.beta) * total_costs, dim=-1)
|
||||
return w
|
||||
|
||||
def _exp_util_from_costs(self, costs):
|
||||
w = jit_calculate_exp_util_from_costs(costs, self.gamma_seq, self.beta)
|
||||
return w
|
||||
|
||||
def _compute_mean(self, w, actions):
|
||||
# get the new means from here
|
||||
new_mean = torch.sum(w * actions, dim=-3)
|
||||
new_mean = jit_blend_mean(self.mean_action, new_mean, self.step_size_mean)
|
||||
return new_mean
|
||||
|
||||
def _compute_mean_covariance(self, costs, actions):
|
||||
if self.cov_type == CovType.FULL_A:
|
||||
log_error("Not implemented")
|
||||
if self.cov_type == CovType.DIAG_A:
|
||||
new_mean, new_cov, new_scale_tril = jit_mean_cov_diag_a(
|
||||
costs,
|
||||
actions,
|
||||
self.gamma_seq,
|
||||
self.mean_action,
|
||||
self.cov_action,
|
||||
self.step_size_mean,
|
||||
self.step_size_cov,
|
||||
self.kappa,
|
||||
self.beta,
|
||||
)
|
||||
self.scale_tril.copy_(new_scale_tril)
|
||||
# self._update_cov_scale(new_cov)
|
||||
|
||||
else:
|
||||
w = self._exp_util_from_costs(costs)
|
||||
w = w.unsqueeze(-1).unsqueeze(-1)
|
||||
new_mean = self._compute_mean(w, actions)
|
||||
new_cov = self._compute_covariance(w, actions)
|
||||
self._update_cov_scale(new_cov)
|
||||
|
||||
return new_mean, new_cov
|
||||
|
||||
def _compute_covariance(self, w, actions):
|
||||
if not self.update_cov:
|
||||
return
|
||||
@@ -200,27 +233,13 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
|
||||
if self.cov_type == CovType.SIGMA_I:
|
||||
delta_actions = actions - self.mean_action.unsqueeze(-3)
|
||||
|
||||
# weighted_delta = w * (delta ** 2).T
|
||||
# cov_update = torch.ean(torch.sum(weighted_delta.T, dim=0))
|
||||
# print(cov_update.shape, self.cov_action)
|
||||
weighted_delta = w * (delta_actions**2)
|
||||
cov_update = torch.mean(
|
||||
torch.sum(torch.sum(weighted_delta, dim=-2), dim=-1), dim=-1, keepdim=True
|
||||
)
|
||||
# raise NotImplementedError("Need to implement covariance update of form sigma*I")
|
||||
|
||||
elif self.cov_type == CovType.DIAG_A:
|
||||
# Diagonal covariance of size AxA
|
||||
# n, b, h, d = delta_actions.shape
|
||||
# delta_actions = delta_actions.view(n,b,h*d)
|
||||
|
||||
# weighted_delta = w * (delta_actions**2)
|
||||
# weighted_delta =
|
||||
# sum across horizon and mean across particles:
|
||||
# cov_update = torch.diag(torch.mean(torch.sum(weighted_delta.T , dim=0), dim=0))
|
||||
# cov_update = torch.mean(torch.sum(weighted_delta, dim=-2), dim=-2).unsqueeze(
|
||||
# -2
|
||||
# ) # .expand(-1,-1,-1)
|
||||
cov_update = jit_diag_a_cov_update(w, actions, self.mean_action)
|
||||
|
||||
elif self.cov_type == CovType.FULL_A:
|
||||
@@ -241,19 +260,22 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
|
||||
|
||||
else:
|
||||
raise ValueError("Unidentified covariance type in update_distribution")
|
||||
cov_update = jit_blend_cov(self.cov_action, cov_update, self.step_size_cov, self.kappa)
|
||||
return cov_update
|
||||
|
||||
def _update_cov_scale(self):
|
||||
def _update_cov_scale(self, new_cov=None):
|
||||
if new_cov is None:
|
||||
new_cov = self.cov_action
|
||||
if not self.update_cov:
|
||||
return
|
||||
if self.cov_type == CovType.SIGMA_I:
|
||||
self.scale_tril = torch.sqrt(self.cov_action)
|
||||
self.scale_tril = torch.sqrt(new_cov)
|
||||
|
||||
elif self.cov_type == CovType.DIAG_A:
|
||||
self.scale_tril.copy_(torch.sqrt(self.cov_action))
|
||||
self.scale_tril.copy_(torch.sqrt(new_cov))
|
||||
|
||||
elif self.cov_type == CovType.FULL_A:
|
||||
self.scale_tril = matrix_cholesky(self.cov_action)
|
||||
self.scale_tril = matrix_cholesky(new_cov)
|
||||
|
||||
elif self.cov_type == CovType.FULL_HA:
|
||||
raise NotImplementedError
|
||||
@@ -263,44 +285,44 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
|
||||
costs = trajectories.costs
|
||||
actions = trajectories.actions
|
||||
|
||||
total_costs = self._compute_total_cost(costs)
|
||||
# Let's reshape to n_problems now:
|
||||
|
||||
# first find the means before doing exponential utility:
|
||||
w = self._exp_util(total_costs)
|
||||
|
||||
# Update best action
|
||||
if self.sample_mode == SampleMode.BEST:
|
||||
best_idx = torch.argmax(w, dim=-1)
|
||||
self.best_traj.copy_(actions[self.problem_col, best_idx])
|
||||
with profiler.record_function("mppi/get_best"):
|
||||
|
||||
if self.store_rollouts and self.visual_traj is not None:
|
||||
vis_seq = getattr(trajectories.state, self.visual_traj)
|
||||
top_values, top_idx = torch.topk(total_costs, 20, dim=1)
|
||||
self.top_values = top_values
|
||||
self.top_idx = top_idx
|
||||
top_trajs = torch.index_select(vis_seq, 0, top_idx[0])
|
||||
for i in range(1, top_idx.shape[0]):
|
||||
trajs = torch.index_select(
|
||||
vis_seq, 0, top_idx[i] + (self.particles_per_problem * i)
|
||||
)
|
||||
top_trajs = torch.cat((top_trajs, trajs), dim=0)
|
||||
if self.top_trajs is None or top_trajs.shape != self.top_trajs.shape:
|
||||
self.top_trajs = top_trajs
|
||||
else:
|
||||
self.top_trajs.copy_(top_trajs)
|
||||
# Update best action
|
||||
if self.sample_mode == SampleMode.BEST:
|
||||
w = self._exp_util_from_costs(costs)
|
||||
best_idx = torch.argmax(w, dim=-1)
|
||||
self.best_traj.copy_(actions[self.problem_col, best_idx])
|
||||
with profiler.record_function("mppi/store_rollouts"):
|
||||
|
||||
w = w.unsqueeze(-1).unsqueeze(-1)
|
||||
if self.store_rollouts and self.visual_traj is not None:
|
||||
total_costs = self._compute_total_cost(costs)
|
||||
vis_seq = getattr(trajectories.state, self.visual_traj)
|
||||
top_values, top_idx = torch.topk(total_costs, 20, dim=1)
|
||||
self.top_values = top_values
|
||||
self.top_idx = top_idx
|
||||
top_trajs = torch.index_select(vis_seq, 0, top_idx[0])
|
||||
for i in range(1, top_idx.shape[0]):
|
||||
trajs = torch.index_select(
|
||||
vis_seq, 0, top_idx[i] + (self.particles_per_problem * i)
|
||||
)
|
||||
top_trajs = torch.cat((top_trajs, trajs), dim=0)
|
||||
if self.top_trajs is None or top_trajs.shape != self.top_trajs.shape:
|
||||
self.top_trajs = top_trajs
|
||||
else:
|
||||
self.top_trajs.copy_(top_trajs)
|
||||
|
||||
new_mean = self._compute_mean(w, actions)
|
||||
# print(new_mean)
|
||||
if self.update_cov:
|
||||
cov_update = self._compute_covariance(w, actions)
|
||||
new_cov = jit_blend_cov(self.cov_action, cov_update, self.step_size_cov, self.kappa)
|
||||
if not self.update_cov:
|
||||
w = self._exp_util_from_costs(costs)
|
||||
w = w.unsqueeze(-1).unsqueeze(-1)
|
||||
new_mean = self._compute_mean(w, actions)
|
||||
else:
|
||||
new_mean, new_cov = self._compute_mean_covariance(costs, actions)
|
||||
self.cov_action.copy_(new_cov)
|
||||
|
||||
self._update_cov_scale()
|
||||
new_mean = jit_blend_mean(self.mean_action, new_mean, self.step_size_mean)
|
||||
self.mean_action.copy_(new_mean)
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -591,20 +613,28 @@ class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
|
||||
return super().generate_rollouts(init_act)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def jit_calculate_exp_util(beta: float, total_costs):
|
||||
w = torch.softmax((-1.0 / beta) * total_costs, dim=-1)
|
||||
return w
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def jit_calculate_exp_util_from_costs(costs, gamma_seq, beta: float):
|
||||
cost_seq = gamma_seq * costs
|
||||
cost_seq = torch.sum(cost_seq, dim=-1, keepdim=False) / gamma_seq[..., 0]
|
||||
w = torch.softmax((-1.0 / beta) * cost_seq, dim=-1)
|
||||
return w
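A toy evaluation of the exponential-utility weights computed above (one problem, two particles, three-step horizon; all numbers invented):

import torch

beta = 0.5
gamma_seq = torch.tensor([[1.0, 0.99, 0.98]])
costs = torch.tensor([[[1.0, 1.0, 1.0],    # particle 0
                       [0.2, 0.2, 0.2]]])  # particle 1, cheaper rollout
cost_seq = torch.sum(gamma_seq * costs, dim=-1) / gamma_seq[..., 0]
w = torch.softmax((-1.0 / beta) * cost_seq, dim=-1)
# w[0, 1] >> w[0, 0]: cheaper particles dominate the distribution update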
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def jit_compute_total_cost(gamma_seq, costs):
|
||||
cost_seq = gamma_seq * costs
|
||||
cost_seq = torch.sum(cost_seq, dim=-1, keepdim=False) / gamma_seq[..., 0]
|
||||
return cost_seq
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def jit_diag_a_cov_update(w, actions, mean_action):
|
||||
delta_actions = actions - mean_action.unsqueeze(-3)
|
||||
|
||||
@@ -616,13 +646,35 @@ def jit_diag_a_cov_update(w, actions, mean_action):
|
||||
return cov_update
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def jit_blend_cov(cov_action, cov_update, step_size_cov: float, kappa: float):
|
||||
new_cov = (1.0 - step_size_cov) * cov_action + step_size_cov * cov_update + kappa
|
||||
return new_cov
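With invented numbers, the blend above keeps part of the previous covariance and adds kappa so the sampling variance cannot collapse to zero:

import torch

cov_action = torch.tensor([0.04])
cov_update = torch.tensor([0.01])
step_size_cov, kappa = 0.7, 1e-4
new_cov = (1.0 - step_size_cov) * cov_action + step_size_cov * cov_update + kappa  # ~0.0191
scale_tril = torch.sqrt(new_cov)  # diagonal covariance: the Cholesky factor is an elementwise sqrt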
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def jit_blend_mean(mean_action, new_mean, step_size_mean: float):
|
||||
mean_update = (1.0 - step_size_mean) * mean_action + step_size_mean * new_mean
|
||||
return mean_update
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def jit_mean_cov_diag_a(
|
||||
costs,
|
||||
actions,
|
||||
gamma_seq,
|
||||
mean_action,
|
||||
cov_action,
|
||||
step_size_mean: float,
|
||||
step_size_cov: float,
|
||||
kappa: float,
|
||||
beta: float,
|
||||
):
|
||||
w = jit_calculate_exp_util_from_costs(costs, gamma_seq, beta)
|
||||
w = w.unsqueeze(-1).unsqueeze(-1)
|
||||
new_mean = torch.sum(w * actions, dim=-3)
|
||||
new_mean = jit_blend_mean(mean_action, new_mean, step_size_mean)
|
||||
cov_update = jit_diag_a_cov_update(w, actions, mean_action)
|
||||
new_cov = jit_blend_cov(cov_action, cov_update, step_size_cov, kappa)
|
||||
new_tril = torch.sqrt(new_cov)
|
||||
return new_mean, new_cov, new_tril
|
||||
|
||||
@@ -263,7 +263,7 @@ class ParticleOptBase(Optimizer, ParticleOptConfig):
|
||||
trajectory.costs = trajectory.costs.view(
|
||||
self.n_problems, self.particles_per_problem, self.horizon
|
||||
)
|
||||
with profiler.record_function("mpc/mppi/update_distribution"):
|
||||
with profiler.record_function("mppi/update_distribution"):
|
||||
self._update_distribution(trajectory)
|
||||
if not self.use_cuda_graph and self.store_debug:
|
||||
self.debug.append(self._get_action_seq(mode=self.sample_mode).clone())
|
||||
|
||||
@@ -18,6 +18,7 @@ import torch.autograd.profiler as profiler
|
||||
|
||||
# CuRobo
|
||||
from curobo.types.base import TensorDeviceType
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
class SquashType(Enum):
|
||||
@@ -74,7 +75,7 @@ def get_stomp_cov(
|
||||
return cov, scale_tril
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def get_stomp_cov_jit(
|
||||
horizon: int,
|
||||
d_action: int,
|
||||
@@ -245,7 +246,7 @@ def gaussian_kl(mean0, cov0, mean1, cov1, cov_type="full"):
|
||||
return term1 + mahalanobis_dist + term3
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def cost_to_go(cost_seq, gamma_seq, only_first=False):
|
||||
# type: (Tensor, Tensor, bool) -> Tensor
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import CSpaceConfig, RobotConfig
|
||||
from curobo.types.state import JointState
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
from curobo.util.tensor_util import cat_sum
|
||||
from curobo.util.tensor_util import cat_sum, cat_sum_horizon
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -104,10 +104,10 @@ class ArmCostConfig:
|
||||
|
||||
@dataclass
|
||||
class ArmBaseConfig(RolloutConfig):
|
||||
model_cfg: KinematicModelConfig
|
||||
cost_cfg: ArmCostConfig
|
||||
constraint_cfg: ArmCostConfig
|
||||
convergence_cfg: ArmCostConfig
|
||||
model_cfg: Optional[KinematicModelConfig] = None
|
||||
cost_cfg: Optional[ArmCostConfig] = None
|
||||
constraint_cfg: Optional[ArmCostConfig] = None
|
||||
convergence_cfg: Optional[ArmCostConfig] = None
|
||||
world_coll_checker: Optional[WorldCollision] = None
|
||||
|
||||
@staticmethod
|
||||
@@ -322,7 +322,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
self.null_convergence = DistCost(self.convergence_cfg.null_space_cfg)
|
||||
|
||||
# set start state:
|
||||
start_state = torch.randn((1, self.dynamics_model.d_state), **vars(self.tensor_args))
|
||||
start_state = torch.randn(
|
||||
(1, self.dynamics_model.d_state), **(self.tensor_args.as_torch_dict())
|
||||
)
|
||||
self._start_state = JointState(
|
||||
position=start_state[:, : self.dynamics_model.d_dof],
|
||||
velocity=start_state[:, : self.dynamics_model.d_dof],
|
||||
@@ -366,9 +368,11 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
)
|
||||
cost_list.append(coll_cost)
|
||||
if return_list:
|
||||
|
||||
return cost_list
|
||||
cost = cat_sum(cost_list)
|
||||
if self.sum_horizon:
|
||||
cost = cat_sum_horizon(cost_list)
|
||||
else:
|
||||
cost = cat_sum(cost_list)
|
||||
return cost
|
||||
|
||||
def constraint_fn(
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#
|
||||
# Standard Library
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
@@ -29,8 +29,9 @@ from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import RobotConfig
|
||||
from curobo.types.tensor import T_BValue_float, T_BValue_int
|
||||
from curobo.util.helpers import list_idx_if_not_none
|
||||
from curobo.util.logger import log_error, log_info
|
||||
from curobo.util.tensor_util import cat_max, cat_sum
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
from curobo.util.tensor_util import cat_max
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# Local Folder
|
||||
from .arm_base import ArmBase, ArmBaseConfig, ArmCostConfig
|
||||
@@ -145,7 +146,7 @@ class ArmReacherConfig(ArmBaseConfig):
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def _compute_g_dist_jit(rot_err_norm, goal_dist):
|
||||
# goal_cost = goal_cost.view(cost.shape)
|
||||
# rot_err_norm = rot_err_norm.view(cost.shape)
|
||||
@@ -319,7 +320,12 @@ class ArmReacher(ArmBase, ArmReacherConfig):
|
||||
g_dist,
|
||||
)
|
||||
cost_list.append(z_vel)
|
||||
cost = cat_sum(cost_list)
|
||||
with profiler.record_function("cat_sum"):
|
||||
if self.sum_horizon:
|
||||
cost = cat_sum_horizon_reacher(cost_list)
|
||||
else:
|
||||
cost = cat_sum_reacher(cost_list)
|
||||
|
||||
return cost
|
||||
|
||||
def convergence_fn(
|
||||
@@ -466,3 +472,15 @@ class ArmReacher(ArmBase, ArmReacherConfig):
|
||||
)
|
||||
for x in pose_costs
|
||||
]
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def cat_sum_reacher(tensor_list: List[torch.Tensor]):
|
||||
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=0)
|
||||
return cat_tensor
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def cat_sum_horizon_reacher(tensor_list: List[torch.Tensor]):
|
||||
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=(0, -1))
|
||||
return cat_tensor
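A small sketch of the two reductions above on fabricated per-term cost tensors (2 problems, horizon 3): cat_sum_reacher keeps the horizon dimension, while cat_sum_horizon_reacher also sums it away, matching the new sum_horizon rollout mode.

import torch

c1 = torch.ones(2, 3)
c2 = 2.0 * torch.ones(2, 3)
per_step = torch.sum(torch.stack([c1, c2], dim=0), dim=0)           # shape (2, 3), all 3.0
per_problem = torch.sum(torch.stack([c1, c2], dim=0), dim=(0, -1))  # shape (2,), all 9.0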
|
||||
|
||||
@@ -21,6 +21,7 @@ import warp as wp
|
||||
from curobo.cuda_robot_model.types import JointLimits
|
||||
from curobo.types.robot import JointState
|
||||
from curobo.types.tensor import T_DOF
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.warp import init_warp
|
||||
|
||||
# Local Folder
|
||||
@@ -267,7 +268,7 @@ class BoundCost(CostBase, BoundCostConfig):
|
||||
return super().update_dt(dt)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
|
||||
# c = weight * torch.sum(torch.nn.functional.relu(torch.max(lower_bounds - p, p - upper_bounds)), dim=-1)
|
||||
|
||||
@@ -281,7 +282,7 @@ def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
|
||||
return c
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def forward_all_bound_cost(
|
||||
p,
|
||||
v,
|
||||
|
||||
@@ -18,6 +18,7 @@ import torch
|
||||
import warp as wp
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.warp import init_warp
|
||||
|
||||
# Local Folder
|
||||
@@ -41,32 +42,32 @@ class DistCostConfig(CostConfig):
|
||||
return super().__post_init__()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def L2_DistCost_jit(vec_weight, disp_vec):
|
||||
return torch.norm(vec_weight * disp_vec, p=2, dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_SQL2_DistCost_jit(vec_weight, disp_vec):
|
||||
return torch.sum(torch.square(vec_weight * disp_vec), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_L1_DistCost_jit(vec_weight, disp_vec):
|
||||
return torch.sum(torch.abs(vec_weight * disp_vec), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def L2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
|
||||
return torch.norm(weight * vec_weight * (g_vec - c_vec), p=2, dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_SQL2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
|
||||
return torch.sum(torch.square(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_L1_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
|
||||
return torch.sum(torch.abs(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ import torch
|
||||
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollision
|
||||
from curobo.rollout.cost.cost_base import CostBase, CostConfig
|
||||
from curobo.rollout.dynamics_model.integration_utils import interpolate_kernel, sum_matrix
|
||||
from curobo.util.logger import log_info
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -48,7 +50,11 @@ class PrimitiveCollisionCostConfig(CostConfig):
|
||||
#: post optimization interpolation to not hit any obstacles.
|
||||
activation_distance: Union[torch.Tensor, float] = 0.0
|
||||
|
||||
#: Setting this flag to true will sum the distance colliding obstacles.
|
||||
sum_collisions: bool = True
|
||||
|
||||
#: Setting this flag to true will sum the distance across spheres of the robot.
|
||||
#: Setting to False will only take the max distance
|
||||
sum_distance: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -103,6 +109,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
self._collision_query_buffer.update_buffer_shape(
|
||||
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
|
||||
)
|
||||
if not self.sum_distance:
|
||||
log_info("sum_distance=False will be slower than sum_distance=True")
|
||||
self.return_loss = True
|
||||
dist = self.sweep_check_fn(
|
||||
robot_spheres_in,
|
||||
self._collision_query_buffer,
|
||||
@@ -115,9 +124,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
return_loss=self.return_loss,
|
||||
)
|
||||
if self.classify:
|
||||
cost = weight_collision(dist, self.weight, self.sum_distance)
|
||||
cost = weight_collision(dist, self.sum_distance)
|
||||
else:
|
||||
cost = weight_distance(dist, self.weight, self.sum_distance)
|
||||
cost = weight_distance(dist, self.sum_distance)
|
||||
return cost
|
||||
|
||||
def sweep_fn(self, robot_spheres_in, env_query_idx: Optional[torch.Tensor] = None):
|
||||
@@ -140,6 +149,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
self._collision_query_buffer.update_buffer_shape(
sampled_spheres.shape, self.tensor_args, self.world_coll_checker.collision_types
)
if not self.sum_distance:
log_info("sum_distance=False will be slower than sum_distance=True")
self.return_loss = True
dist = self.coll_check_fn(
sampled_spheres.contiguous(),
self._collision_query_buffer,
@@ -151,9 +163,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
dist = dist.view(batch_size, new_horizon, n_spheres)

if self.classify:
cost = weight_sweep_collision(self.int_sum_mat, dist, self.weight, self.sum_distance)
cost = weight_sweep_collision(self.int_sum_mat, dist, self.sum_distance)
else:
cost = weight_sweep_distance(self.int_sum_mat, dist, self.weight, self.sum_distance)
cost = weight_sweep_distance(self.int_sum_mat, dist, self.sum_distance)

return cost

@@ -161,6 +173,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
self._collision_query_buffer.update_buffer_shape(
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
)
if not self.sum_distance:
log_info("sum_distance=False will be slower than sum_distance=True")
self.return_loss = True
dist = self.coll_check_fn(
robot_spheres_in,
self._collision_query_buffer,
@@ -168,12 +183,13 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
env_query_idx=env_query_idx,
activation_distance=self.activation_distance,
return_loss=self.return_loss,
sum_collisions=self.sum_collisions,
)

if self.classify:
cost = weight_collision(dist, self.weight, self.sum_distance)
cost = weight_collision(dist, self.sum_distance)
else:
cost = weight_distance(dist, self.weight, self.sum_distance)
cost = weight_distance(dist, self.sum_distance)
return cost

def update_dt(self, dt: Union[float, torch.Tensor]):
@@ -184,31 +200,43 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
return self._collision_query_buffer.get_gradient_buffer()

@torch.jit.script
def weight_sweep_distance(int_mat, dist, weight, sum_cost: bool):
dist = torch.sum(dist, dim=-1)
@get_torch_jit_decorator()
def weight_sweep_distance(int_mat, dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
dist = dist @ int_mat
return dist

@torch.jit.script
def weight_sweep_collision(int_mat, dist, weight, sum_cost: bool):
dist = torch.sum(dist, dim=-1)
@get_torch_jit_decorator()
def weight_sweep_collision(int_mat, dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]

dist = torch.where(dist > 0, dist + 1.0, dist)
dist = dist @ int_mat
return dist

@torch.jit.script
def weight_distance(dist, weight, sum_cost: bool):
@get_torch_jit_decorator()
def weight_distance(dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
return dist

@torch.jit.script
def weight_collision(dist, weight, sum_cost: bool):
@get_torch_jit_decorator()
def weight_collision(dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]

dist = torch.where(dist > 0, dist + 1.0, dist)
return dist
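Note: the reworked weight_sweep_* / weight_* helpers above drop the unused weight argument and reduce the per-sphere distances either by a sum or by a max over spheres, selected by sum_cost. A minimal sketch of that reduction with made-up shapes (illustrative only, not code from the repository):

# Illustrative only: mirrors the sum/max reduction in weight_sweep_distance above.
import torch

dist = torch.rand(2, 30, 64)      # assumed (batch, swept steps, n_spheres)
int_mat = torch.eye(30)           # stand-in for the sweep summation matrix
sum_cost = False                  # False -> max over spheres, True -> sum over spheres
reduced = dist.sum(dim=-1) if sum_cost else dist.max(dim=-1)[0]
cost = reduced @ int_mat          # fold swept steps back onto the horizon
print(cost.shape)                 # torch.Size([2, 30])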
@@ -17,6 +17,7 @@ import torch

# CuRobo
from curobo.rollout.dynamics_model.kinematic_model import TimeTrajConfig
from curobo.util.torch_utils import get_torch_jit_decorator

# Local Folder
from .cost_base import CostBase, CostConfig
@@ -72,7 +73,7 @@ class StopCost(CostBase, StopCostConfig):
return cost

@torch.jit.script
@get_torch_jit_decorator()
def velocity_cost(vels, weight, max_vel):
vel_abs = torch.abs(vels)
vel_abs = torch.nn.functional.relu(vel_abs - max_vel[: vels.shape[1]])

@@ -13,11 +13,14 @@
# Third Party
import torch

# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator

# Local Folder
from .cost_base import CostBase, CostConfig

@torch.jit.script
@get_torch_jit_decorator()
def st_cost(ee_pos_batch, vec_weight, weight):
ee_plus_one = torch.roll(ee_pos_batch, 1, dims=1)

@@ -11,11 +11,14 @@
# Third Party
import torch

# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator

# Local Folder
from .cost_base import CostBase

@torch.jit.script
@get_torch_jit_decorator()
def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
# return weight * torch.square(torch.linalg.norm(cost, dim=-1, ord=1))
# return weight * torch.sum(torch.square(cost), dim=-1)
@@ -24,7 +27,7 @@ def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
return torch.sum(torch.square(cost) * weight, dim=-1)

@torch.jit.script
@get_torch_jit_decorator()
def run_squared_sum(
cost: torch.Tensor, weight: torch.Tensor, run_weight: torch.Tensor
) -> torch.Tensor:
@@ -35,13 +38,13 @@ def run_squared_sum(
# return torch.sum(torch.square(cost), dim=-1) * weight * run_weight

@torch.jit.script
@get_torch_jit_decorator()
def backward_squared_sum(cost_vec, w):
return 2.0 * w * cost_vec # * g_out.unsqueeze(-1)
# return w * g_out.unsqueeze(-1)

@torch.jit.script
@get_torch_jit_decorator()
def backward_run_squared_sum(cost_vec, w, r_w):
return 2.0 * w * r_w.unsqueeze(-1) * cost_vec # * g_out.unsqueeze(-1)
# return w * r_w.unsqueeze(-1) * cost_vec * g_out.unsqueeze(-1)
@@ -25,6 +25,7 @@ from curobo.curobolib.tensor_step import (
)
from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
from curobo.util.torch_utils import get_torch_jit_decorator

def build_clique_matrix(horizon, dt, device="cpu", dtype=torch.float32):
@@ -154,7 +155,7 @@ def build_start_state_mask(horizon, tensor_args: TensorDeviceType):
return mask, n_mask

# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_jerk(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor

@@ -176,7 +177,7 @@ def tensor_step_jerk(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_m
return state_seq

# @torch.jit.script
# @get_torch_jit_decorator()
def euler_integrate(q_0, u, diag_dt, integrate_matrix):
# q_new = q_0 + torch.matmul(integrate_matrix, torch.matmul(diag_dt, u))
q_new = q_0 + torch.matmul(integrate_matrix, u)
@@ -184,7 +185,7 @@ def euler_integrate(q_0, u, diag_dt, integrate_matrix):
return q_new

# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_acc(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
# This is batch,n_dof
@@ -207,7 +208,7 @@ def tensor_step_acc(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_ma
return state_seq

@torch.jit.script
@get_torch_jit_decorator()
def jit_tensor_step_pos_clique_contiguous(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
state_position = (start_position.unsqueeze(1).transpose(1, 2) @ mask.transpose(0, 1)) + (
pos_act.transpose(1, 2) @ n_mask.transpose(0, 1)
@@ -222,7 +223,7 @@ def jit_tensor_step_pos_clique_contiguous(pos_act, start_position, mask, n_mask,
return state_position, state_vel, state_acc, state_jerk

@torch.jit.script
@get_torch_jit_decorator()
def jit_tensor_step_pos_clique(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
state_position = mask @ start_position.unsqueeze(1) + n_mask @ pos_act
state_vel = fd_1 @ state_position
@@ -231,7 +232,7 @@ def jit_tensor_step_pos_clique(pos_act, start_position, mask, n_mask, fd_1, fd_2
return state_position, state_vel, state_acc, state_jerk

@torch.jit.script
@get_torch_jit_decorator()
def jit_backward_pos_clique(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
p_grad = (
grad_p
@@ -247,7 +248,7 @@ def jit_backward_pos_clique(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2,
return u_grad

@torch.jit.script
@get_torch_jit_decorator()
def jit_backward_pos_clique_contiguous(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
p_grad = grad_p + (
grad_j.transpose(-1, -2) @ fd_3
@@ -532,7 +533,7 @@ class CliqueTensorStepIdxCentralDifferenceKernel(torch.autograd.Function):
start_position,
start_velocity,
start_acceleration,
start_idx,
start_idx.contiguous(),
traj_dt,
out_position.shape[0],
out_position.shape[1],
@@ -750,7 +751,7 @@ class AccelerationTensorStepIdxKernel(torch.autograd.Function):
return u_grad, None, None, None, None, None, None, None, None, None, None

# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_pos_clique(
state: JointState,
act: torch.Tensor,
@@ -786,7 +787,7 @@ def step_acc_semi_euler(state, act, diag_dt, n_dofs, integrate_matrix):
return state_seq

# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_acc_semi_euler(
state, act, state_seq, diag_dt, integrate_matrix, integrate_matrix_pos
):
@@ -806,7 +807,7 @@ def tensor_step_acc_semi_euler(
return state_seq

# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_vel(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Tensor) -> Tensor

@@ -830,7 +831,7 @@ def tensor_step_vel(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_ma
return state_seq

# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_pos(state, act, state_seq, fd_matrix):
# This is batch,n_dof
state_seq.position[:, 0, :] = state.position
@@ -850,7 +851,7 @@ def tensor_step_pos(state, act, state_seq, fd_matrix):
return state_seq

# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_pos_ik(act, state_seq):
state_seq.position = act
return state_seq
@@ -869,7 +870,7 @@ def tensor_linspace(start_tensor, end_tensor, steps=10):

def sum_matrix(h, int_steps, tensor_args):
sum_mat = torch.zeros(((h - 1) * int_steps, h), **vars(tensor_args))
sum_mat = torch.zeros(((h - 1) * int_steps, h), **(tensor_args.as_torch_dict()))
for i in range(h - 1):
sum_mat[i * int_steps : i * int_steps + int_steps, i] = 1.0
# hack:
@@ -19,6 +19,7 @@ import torch
# CuRobo
from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
from curobo.util.torch_utils import get_torch_jit_decorator

# Local Folder
from .integration_utils import (
@@ -544,7 +545,7 @@ class TensorStepPositionCliqueKernel(TensorStepBase):
return new_signal

@torch.jit.script
@get_torch_jit_decorator(force_jit=True)
def filter_signal_jit(signal, kernel):
b, h, dof = signal.shape

@@ -36,6 +36,7 @@ from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_info
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.tensor_util import copy_tensor
from curobo.util.torch_utils import get_torch_jit_decorator

@dataclass
@@ -298,9 +299,9 @@ class Goal(Sequence):
if self.goal_pose is not None:
self.goal_pose = self.goal_pose.to(tensor_args)
if self.goal_state is not None:
self.goal_state = self.goal_state.to(**vars(tensor_args))
self.goal_state = self.goal_state.to(**(tensor_args.as_torch_dict()))
if self.current_state is not None:
self.current_state = self.current_state.to(**vars(tensor_args))
self.current_state = self.current_state.to(**(tensor_args.as_torch_dict()))
return self

def copy_(self, goal: Goal, update_idx_buffers: bool = True):
@@ -350,6 +351,7 @@ class Goal(Sequence):
if ref_buffer is not None:
ref_buffer = ref_buffer.copy_(buffer)
else:
log_info("breaking reference")
ref_buffer = buffer.clone()
return ref_buffer

@@ -414,6 +416,7 @@ class Goal(Sequence):
@dataclass
class RolloutConfig:
tensor_args: TensorDeviceType
sum_horizon: bool = False

class RolloutBase:
@@ -578,7 +581,7 @@ class RolloutBase:
return q_js

@torch.jit.script
@get_torch_jit_decorator()
def tensor_repeat_seeds(tensor, num_seeds: int):
a = (
tensor.view(tensor.shape[0], 1, tensor.shape[-1])

@@ -19,7 +19,10 @@ import torch
@dataclass(frozen=True)
class TensorDeviceType:
device: torch.device = torch.device("cuda", 0)
dtype: torch.float32 = torch.float32
dtype: torch.dtype = torch.float32
collision_geometry_dtype: torch.dtype = torch.float32
collision_gradient_dtype: torch.dtype = torch.float32
collision_distance_dtype: torch.dtype = torch.float32

@staticmethod
def from_basic(device: str, dev_id: int):
@@ -36,3 +39,6 @@ class TensorDeviceType:

def cpu(self):
return TensorDeviceType(device=torch.device("cpu"), dtype=self.dtype)

def as_torch_dict(self):
return {"device": self.device, "dtype": self.dtype}
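Note: the new TensorDeviceType.as_torch_dict() above is what the rest of this commit switches to in place of vars(tensor_args). A minimal usage sketch (defaults as defined in the hunk above):

# Sketch only; requires a CUDA device with the default TensorDeviceType.
import torch
from curobo.types.base import TensorDeviceType

tensor_args = TensorDeviceType()                      # cuda:0, float32 by default
x = torch.zeros(3, **(tensor_args.as_torch_dict()))   # expands to device=..., dtype=...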
@@ -39,7 +39,7 @@ class CameraObservation:
resolution: Optional[List[int]] = None
pose: Optional[Pose] = None
intrinsics: Optional[torch.Tensor] = None
timestamp: float = 0.0
timestamp: Optional[torch.Tensor] = None

def filter_depth(self, distance: float = 0.01):
self.depth_image = torch.where(self.depth_image < distance, 0, self.depth_image)
@@ -62,6 +62,8 @@ class CameraObservation:
self.projection_rays.copy_(new_data.projection_rays)
if self.pose is not None:
self.pose.copy_(new_data.pose)
if self.timestamp is not None:
self.timestamp.copy_(new_data.timestamp)
self.resolution = new_data.resolution

@record_function("camera/clone")
@@ -73,13 +75,41 @@ class CameraObservation:
intrinsics=self.intrinsics.clone() if self.intrinsics is not None else None,
resolution=self.resolution,
pose=self.pose.clone() if self.pose is not None else None,
timestamp=self.timestamp.clone() if self.timestamp is not None else None,
image_segmentation=(
self.image_segmentation.clone() if self.image_segmentation is not None else None
),
projection_matrix=(
self.projection_matrix.clone() if self.projection_matrix is not None else None
),
projection_rays=(
self.projection_rays.clone() if self.projection_rays is not None else None
),
name=self.name,
)

def to(self, device: torch.device):
if self.rgb_image is not None:
self.rgb_image = self.rgb_image.to(device=device)
if self.depth_image is not None:
self.depth_image = self.depth_image.to(device=device)

self.rgb_image = self.rgb_image.to(device=device) if self.rgb_image is not None else None
self.depth_image = (
self.depth_image.to(device=device) if self.depth_image is not None else None
)

self.image_segmentation = (
self.image_segmentation.to(device=device)
if self.image_segmentation is not None
else None
)
self.projection_matrix = (
self.projection_matrix.to(device=device) if self.projection_matrix is not None else None
)
self.projection_rays = (
self.projection_rays.to(device=device) if self.projection_rays is not None else None
)
self.intrinsics = self.intrinsics.to(device=device) if self.intrinsics is not None else None
self.timestamp = self.timestamp.to(device=device) if self.timestamp is not None else None
self.pose = self.pose.to(device=device) if self.pose is not None else None

return self

def get_pointcloud(self):
@@ -114,3 +144,56 @@ class CameraObservation:
self.projection_rays = project_rays

self.projection_rays.copy_(project_rays)

def stack(self, new_observation: CameraObservation, dim: int = 0):
rgb_image = (
torch.stack((self.rgb_image, new_observation.rgb_image), dim=dim)
if self.rgb_image is not None
else None
)
depth_image = (
torch.stack((self.depth_image, new_observation.depth_image), dim=dim)
if self.depth_image is not None
else None
)
image_segmentation = (
torch.stack((self.image_segmentation, new_observation.image_segmentation), dim=dim)
if self.image_segmentation is not None
else None
)
projection_matrix = (
torch.stack((self.projection_matrix, new_observation.projection_matrix), dim=dim)
if self.projection_matrix is not None
else None
)
projection_rays = (
torch.stack((self.projection_rays, new_observation.projection_rays), dim=dim)
if self.projection_rays is not None
else None
)
resolution = self.resolution

pose = self.pose.stack(new_observation.pose) if self.pose is not None else None

intrinsics = (
torch.stack((self.intrinsics, new_observation.intrinsics), dim=dim)
if self.intrinsics is not None
else None
)
timestamp = (
torch.stack((self.timestamp, new_observation.timestamp), dim=dim)
if self.timestamp is not None
else None
)
return CameraObservation(
name=self.name,
rgb_image=rgb_image,
depth_image=depth_image,
image_segmentation=image_segmentation,
projection_matrix=projection_matrix,
projection_rays=projection_rays,
resolution=resolution,
pose=pose,
intrinsics=intrinsics,
timestamp=timestamp,
)
@@ -12,7 +12,7 @@ from __future__ import annotations

# Standard Library
from dataclasses import dataclass
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union

# Third Party
import numpy as np
@@ -35,6 +35,7 @@ from curobo.types.base import TensorDeviceType
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.tensor_util import clone_if_not_none, copy_tensor
from curobo.util.torch_utils import get_torch_jit_decorator

# Local Folder
from .tensor import T_BPosition, T_BQuaternion, T_BRotation
@@ -263,8 +264,17 @@ class Pose(Sequence):
# rotation=clone_if_not_none(self.rotation),
)

def to(self, tensor_args: TensorDeviceType):
t_type = vars(tensor_args)
def to(
self,
tensor_args: Optional[TensorDeviceType] = None,
device: Optional[torch.device] = None,
):
if tensor_args is None and device is None:
log_error("Pose.to() requires tensor_args or device")
if tensor_args is not None:
t_type = vars(tensor_args.as_torch_dict())
else:
t_type = {"device": device}
if self.position is not None:
self.position = self.position.to(**t_type)
if self.quaternion is not None:
@@ -338,7 +348,7 @@ class Pose(Sequence):
return p_distance, quat_distance

def angular_distance(self, other_pose: Pose, use_phi3: bool = False):
"""This function computes the angular distance \phi_3.
"""This function computes the angular distance phi_3.

See Huynh, Du Q. "Metrics for 3D rotations: Comparison and analysis." Journal of Mathematical
Imaging and Vision 35 (2009): 155-164.
@@ -461,9 +471,9 @@ def quat_multiply(q1, q2, q_res):
return q_res

@torch.jit.script
@get_torch_jit_decorator()
def angular_distance_phi3(goal_quat, current_quat):
"""This function computes the angular distance \phi_3.
"""This function computes the angular distance phi_3.

See Huynh, Du Q. "Metrics for 3D rotations: Comparison and analysis." Journal of Mathematical
Imaging and Vision 35 (2009): 155-164.
@@ -524,7 +534,7 @@ class OrientationError(Function):
return None, grad_mul, None

@torch.jit.script
@get_torch_jit_decorator()
def normalize_quaternion(in_quaternion: torch.Tensor) -> torch.Tensor:
k = torch.sign(in_quaternion[..., 0:1])
# NOTE: torch sign returns 0 as sign value when value is 0.0
@@ -30,6 +30,7 @@ from curobo.util.tensor_util import (
fd_tensor,
tensor_repeat_seeds,
)
from curobo.util.torch_utils import get_torch_jit_decorator

@dataclass
@@ -211,10 +212,10 @@ class JointState(State):
j = None
v = a = None
max_idx = 0
if isinstance(idx, List):
idx = torch.as_tensor(idx, device=self.position.device, dtype=torch.long)
if isinstance(idx, int):
max_idx = idx
elif isinstance(idx, List):
max_idx = max(idx)
elif isinstance(idx, torch.Tensor):
max_idx = torch.max(idx)
if max_idx >= self.position.shape[0]:
@@ -223,31 +224,19 @@ class JointState(State):
+ " index out of range, current state is of length "
+ str(self.position.shape)
)
p = self.position[idx]
if self.velocity is not None:
if max_idx >= self.velocity.shape[0]:
raise ValueError(
str(max_idx)
+ " index out of range, current velocity is of length "
+ str(self.velocity.shape)
)
v = self.velocity[idx]
if self.acceleration is not None:
if max_idx >= self.acceleration.shape[0]:
raise ValueError(
str(max_idx)
+ " index out of range, current acceleration is of length "
+ str(self.acceleration.shape)
)
a = self.acceleration[idx]
if self.jerk is not None:
if max_idx >= self.jerk.shape[0]:
raise ValueError(
str(max_idx)
+ " index out of range, current jerk is of length "
+ str(self.jerk.shape)
)
j = self.jerk[idx]
if isinstance(idx, int):
p, v, a, j = jit_get_index_int(
self.position, self.velocity, self.acceleration, self.jerk, idx
)
elif isinstance(idx, torch.Tensor):
p, v, a, j = jit_get_index(
self.position, self.velocity, self.acceleration, self.jerk, idx
)
else:
p, v, a, j = fn_get_index(
self.position, self.velocity, self.acceleration, self.jerk, idx
)

return JointState(p, v, a, joint_names=self.joint_names, jerk=j)

def __len__(self):
@@ -514,6 +503,88 @@ class JointState(State):
jerk = self.jerk * (dt**3)
return JointState(self.position, vel, acc, self.joint_names, jerk, self.tensor_args)

def scale_by_dt(self, dt: torch.Tensor, new_dt: torch.Tensor):
vel, acc, jerk = jit_js_scale(self.velocity, self.acceleration, self.jerk, dt, new_dt)

return JointState(self.position, vel, acc, self.joint_names, jerk, self.tensor_args)

@property
def shape(self):
return self.position.shape

@get_torch_jit_decorator()
def jit_js_scale(
vel: Union[None, torch.Tensor],
acc: Union[None, torch.Tensor],
jerk: Union[None, torch.Tensor],
dt: torch.Tensor,
new_dt: torch.Tensor,
):
scale_dt = dt / new_dt
if vel is not None:
vel = vel * scale_dt
if acc is not None:
acc = acc * scale_dt * scale_dt
if jerk is not None:
jerk = jerk * scale_dt * scale_dt * scale_dt
return vel, acc, jerk

@get_torch_jit_decorator()
def jit_get_index(
position: torch.Tensor,
velocity: Union[torch.Tensor, None],
acc: Union[torch.Tensor, None],
jerk: Union[torch.Tensor, None],
idx: torch.Tensor,
):

position = position[idx]
if velocity is not None:
velocity = velocity[idx]
if acc is not None:
acc = acc[idx]
if jerk is not None:
jerk = jerk[idx]

return position, velocity, acc, jerk

def fn_get_index(
position: torch.Tensor,
velocity: Union[torch.Tensor, None],
acc: Union[torch.Tensor, None],
jerk: Union[torch.Tensor, None],
idx: torch.Tensor,
):

position = position[idx]
if velocity is not None:
velocity = velocity[idx]
if acc is not None:
acc = acc[idx]
if jerk is not None:
jerk = jerk[idx]

return position, velocity, acc, jerk

@get_torch_jit_decorator()
def jit_get_index_int(
position: torch.Tensor,
velocity: Union[torch.Tensor, None],
acc: Union[torch.Tensor, None],
jerk: Union[torch.Tensor, None],
idx: int,
):

position = position[idx]
if velocity is not None:
velocity = velocity[idx]
if acc is not None:
acc = acc[idx]
if jerk is not None:
jerk = jerk[idx]

return position, velocity, acc, jerk
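Note: scale_by_dt/jit_js_scale above retime derivatives by scale = dt / new_dt, applied linearly to velocity, quadratically to acceleration, and cubically to jerk. A tiny numeric sketch with made-up values:

# Illustrative numbers only.
import torch

dt, new_dt = torch.tensor(0.5), torch.tensor(0.25)
scale = dt / new_dt                                   # 2.0
vel = acc = jerk = torch.ones(3)
print(vel * scale, acc * scale**2, jerk * scale**3)   # 2.0, 4.0, 8.0 per joint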
@@ -9,54 +9,24 @@
# its affiliates is strictly prohibited.
#

""" This module contains aliases for structured Tensors, improving readability."""

# Third Party
import torch

# CuRobo
from curobo.util.logger import log_warn

try:
# Third Party
from torchtyping import TensorType
except ImportError:
log_warn("torchtyping could not be imported, falling back to basic types")
TensorType = None
# Third Party
import torch
b_dof = ["batch", "dof"]
b_value = ["batch", "value"]
bh_value = ["batch", "horizon", "value"]
bh_dof = ["batch", "horizon", "dof"]
h_dof = ["horizon", "dof"]
T_DOF = torch.Tensor #: Tensor of shape [degrees of freedom]
T_BDOF = torch.Tensor #: Tensor of shape [batch, degrees of freedom]
T_BHDOF_float = torch.Tensor #: Tensor of shape [batch, horizon, degrees of freedom]
T_HDOF_float = torch.Tensor #: Tensor of shape [horizon, degrees of freedom]

if TensorType is not None:
T_DOF = TensorType[tuple(["dof"] + [float])]
T_BDOF = TensorType[tuple(b_dof + [float])]
T_BValue_float = TensorType[tuple(b_value + [float])]
T_BHValue_float = TensorType[tuple(bh_value + [float])]
T_BValue_bool = TensorType[tuple(b_value + [bool])]
T_BValue_int = TensorType[tuple(b_value + [int])]
T_BValue_float = torch.Tensor #: Float Tensor of shape [batch, 1].
T_BHValue_float = torch.Tensor #: Float Tensor of shape [batch, horizon, 1].
T_BValue_bool = torch.Tensor #: Bool Tensor of shape [batch, horizon, 1].
T_BValue_int = torch.Tensor #: Int Tensor of shape [batch, horizon, 1].

T_BPosition = TensorType["batch", "xyz":3, float]
T_BQuaternion = TensorType["batch", "wxyz":4, float]
T_BRotation = TensorType["batch", 3, 3, float]
T_Position = TensorType["xyz":3, float]
T_Quaternion = TensorType["wxyz":4, float]
T_Rotation = TensorType[3, 3, float]

T_BHDOF_float = TensorType[tuple(bh_dof + [float])]
T_HDOF_float = TensorType[tuple(h_dof + [float])]
else:
T_DOF = torch.Tensor
T_BDOF = torch.Tensor
T_BValue_float = torch.Tensor
T_BHValue_float = torch.Tensor
T_BValue_bool = torch.Tensor
T_BValue_int = torch.Tensor

T_BPosition = torch.Tensor
T_BQuaternion = torch.Tensor
T_BRotation = torch.Tensor
T_Position = torch.Tensor
T_Quaternion = torch.Tensor
T_Rotation = torch.Tensor

T_BHDOF_float = torch.Tensor
T_HDOF_float = torch.Tensor
T_BPosition = torch.Tensor #: Tensor of shape [batch, 3].
T_BQuaternion = torch.Tensor #: Tensor of shape [batch, 4].
T_BRotation = torch.Tensor #: Tensor of shape [batch, 3,3].
@@ -25,6 +25,7 @@ from torch.distributions.multivariate_normal import MultivariateNormal
# CuRobo
from curobo.types.base import TensorDeviceType
from curobo.util.logger import log_error, log_warn
from curobo.util.torch_utils import get_torch_jit_decorator

# Local Folder
from ..opt.particle.particle_opt_utils import get_stomp_cov
@@ -511,7 +512,7 @@ class HaltonGenerator:
return gaussian_halton_samples

@torch.jit.script
@get_torch_jit_decorator()
def gaussian_transform(
uniform_samples: torch.Tensor, proj_mat: torch.Tensor, i_mat: torch.Tensor, std_dev: float
):

@@ -14,6 +14,9 @@ from typing import List
# Third Party
import torch

# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator

def check_tensor_shapes(new_tensor: torch.Tensor, mem_tensor: torch.Tensor):
if not isinstance(mem_tensor, torch.Tensor):
@@ -65,13 +68,19 @@ def clone_if_not_none(x):
return None

@torch.jit.script
@get_torch_jit_decorator()
def cat_sum(tensor_list: List[torch.Tensor]):
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=0)
return cat_tensor

@torch.jit.script
@get_torch_jit_decorator()
def cat_sum_horizon(tensor_list: List[torch.Tensor]):
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=(0, -1))
return cat_tensor

@get_torch_jit_decorator()
def cat_max(tensor_list: List[torch.Tensor]):
cat_tensor = torch.max(torch.stack(tensor_list, dim=0), dim=0)[0]
return cat_tensor
@@ -85,7 +94,7 @@ def tensor_repeat_seeds(tensor, num_seeds):
)

@torch.jit.script
@get_torch_jit_decorator()
def fd_tensor(p: torch.Tensor, dt: torch.Tensor):
out = ((torch.roll(p, -1, -2) - p) * (1 / dt).unsqueeze(-1))[..., :-1, :]
return out
@@ -8,12 +8,15 @@
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
import os

# Third Party
import torch
from packaging import version

# CuRobo
from curobo.util.logger import log_warn
from curobo.util.logger import log_info, log_warn

def find_first_idx(array, value, EQUAL=False):
@@ -31,13 +34,119 @@ def find_last_idx(array, value):

def is_cuda_graph_available():
if version.parse(torch.__version__) < version.parse("1.10"):
log_warn("WARNING: Disabling CUDA Graph as pytorch < 1.10")
log_warn("Disabling CUDA Graph as pytorch < 1.10")
return False
return True

def is_torch_compile_available():
force_compile = os.environ.get("CUROBO_TORCH_COMPILE_FORCE")
if force_compile is not None and bool(int(force_compile)):
return True
if version.parse(torch.__version__) < version.parse("2.0"):
log_warn("WARNING: Disabling compile as pytorch < 2.0")
log_info("Disabling torch.compile as pytorch < 2.0")
return False

env_variable = os.environ.get("CUROBO_TORCH_COMPILE_DISABLE")

if env_variable is None:
log_info("Environment variable for CUROBO_TORCH_COMPILE is not set, Disabling.")

return False

if bool(int(env_variable)):
log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Disable")
return False

log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Enable")

try:
torch.compile
except:
log_info("Could not find torch.compile, disabling Torch Compile.")
return False
try:
torch._dynamo
except:
log_info("Could not find torch._dynamo, disabling Torch Compile.")
return False
try:
# Third Party
import triton
except:
log_info("Could not find triton, disabling Torch Compile.")
return False

return True

def get_torch_compile_options() -> dict:
options = {}
if is_torch_compile_available():
# Third Party
from torch._inductor import config

torch._dynamo.config.suppress_errors = True
use_options = {
"max_autotune": True,
"use_mixed_mm": True,
"conv_1x1_as_mm": True,
"coordinate_descent_tuning": True,
"epilogue_fusion": False,
"coordinate_descent_check_all_directions": True,
"force_fuse_int_mm_with_mul": True,
"triton.cudagraphs": False,
"aggressive_fusion": True,
"split_reductions": False,
"worker_start_method": "spawn",
}
for k in use_options.keys():
if hasattr(config, k):
options[k] = use_options[k]
else:
log_info("Not found in torch.compile: " + k)
return options

def disable_torch_compile_global():
if is_torch_compile_available():
torch._dynamo.config.disable = True
return True
return False

def set_torch_compile_global_options():
if is_torch_compile_available():
# Third Party
from torch._inductor import config

torch._dynamo.config.suppress_errors = True
if hasattr(config, "conv_1x1_as_mm"):
torch._inductor.config.conv_1x1_as_mm = True
if hasattr(config, "coordinate_descent_tuning"):
torch._inductor.config.coordinate_descent_tuning = True
if hasattr(config, "epilogue_fusion"):
torch._inductor.config.epilogue_fusion = False
if hasattr(config, "coordinate_descent_check_all_directions"):
torch._inductor.config.coordinate_descent_check_all_directions = True
if hasattr(config, "force_fuse_int_mm_with_mul"):
torch._inductor.config.force_fuse_int_mm_with_mul = True
if hasattr(config, "use_mixed_mm"):
torch._inductor.config.use_mixed_mm = True
return True
return False

def get_torch_jit_decorator(
force_jit: bool = False, dynamic: bool = True, only_valid_for_compile: bool = False
):
if not force_jit and is_torch_compile_available():
return torch.compile(options=get_torch_compile_options(), dynamic=dynamic)
elif not only_valid_for_compile:
return torch.jit.script
else:
return empty_decorator

def empty_decorator(function):
return function
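Note on the selection logic above (behaviour read from this hunk, not extra documentation): torch.compile is used only when PyTorch >= 2.0, triton imports, and CUROBO_TORCH_COMPILE_DISABLE is explicitly set to 0 (or CUROBO_TORCH_COMPILE_FORCE=1); otherwise the decorator falls back to torch.jit.script (or to a no-op when only_valid_for_compile=True). A small opt-in sketch:

# Sketch: opting in to torch.compile for the JIT'd helpers (env variable names from the code above).
import os
os.environ["CUROBO_TORCH_COMPILE_DISABLE"] = "0"   # "0" enables compile, "1" keeps torch.jit.script

from curobo.util.torch_utils import get_torch_jit_decorator

@get_torch_jit_decorator()
def scaled_sum(x, w):
    return (x * w).sum(dim=-1)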
@@ -25,6 +25,7 @@ from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.sample_lib import bspline
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.warp_interpolation import get_cuda_linear_interpolation

@@ -114,7 +115,7 @@ def get_spline_interpolated_trajectory(raw_traj, des_horizon, degree=5):

for i in range(cpu_traj.shape[-1]):
retimed_traj[:, i] = bspline(cpu_traj[:, i], n=des_horizon, degree=degree)
retimed_traj = retimed_traj.to(**vars(tensor_args))
retimed_traj = retimed_traj.to(**(tensor_args.as_torch_dict()))
return retimed_traj

@@ -385,7 +386,7 @@ def get_interpolated_trajectory(
kind=kind,
last_step=des_horizon,
)
retimed_traj = retimed_traj.to(**vars(tensor_args))
retimed_traj = retimed_traj.to(**(tensor_args.as_torch_dict()))
out_traj_state.position[b, :interpolation_steps, :] = retimed_traj
out_traj_state.position[b, interpolation_steps:, :] = retimed_traj[
interpolation_steps - 1 : interpolation_steps, :
@@ -438,7 +439,39 @@ def linear_smooth(
return y_new

@torch.jit.script
@get_torch_jit_decorator()
def calculate_dt_fixed(
vel: torch.Tensor,
acc: torch.Tensor,
jerk: torch.Tensor,
max_vel: torch.Tensor,
max_acc: torch.Tensor,
max_jerk: torch.Tensor,
raw_dt: torch.Tensor,
interpolation_dt: float,
):
# compute scaled dt:
max_v_arr = torch.max(torch.abs(vel), dim=-2)[0] # output is batch, dof

max_acc_arr = torch.max(torch.abs(acc), dim=-2)[0]
max_jerk_arr = torch.max(torch.abs(jerk), dim=-2)[0]

vel_scale_dt = (max_v_arr) / (max_vel.view(1, max_v_arr.shape[-1])) # batch,dof
acc_scale_dt = max_acc_arr / (max_acc.view(1, max_acc_arr.shape[-1]))
jerk_scale_dt = max_jerk_arr / (max_jerk.view(1, max_jerk_arr.shape[-1]))

dt_score_vel = raw_dt * torch.max(vel_scale_dt, dim=-1)[0] # batch, 1
dt_score_acc = raw_dt * torch.sqrt((torch.max(acc_scale_dt, dim=-1)[0]))
dt_score_jerk = raw_dt * torch.pow((torch.max(jerk_scale_dt, dim=-1)[0]), 1 / 3)
dt_score = torch.maximum(dt_score_vel, dt_score_acc)
dt_score = torch.maximum(dt_score, dt_score_jerk)
dt_score = torch.clamp(dt_score, interpolation_dt, raw_dt)
# NOTE: this dt score is not dt, rather a scaling to convert velocity, acc, jerk that was
# computed with raw_dt to a new dt
return dt_score

@get_torch_jit_decorator(force_jit=True)
def calculate_dt(
vel: torch.Tensor,
acc: torch.Tensor,
@@ -470,7 +503,7 @@ def calculate_dt(
return dt_score

@torch.jit.script
@get_torch_jit_decorator(force_jit=True)
def calculate_dt_no_clamp(
vel: torch.Tensor,
acc: torch.Tensor,
@@ -497,7 +530,7 @@ def calculate_dt_no_clamp(
return dt_score

@torch.jit.script
@get_torch_jit_decorator()
def calculate_tsteps(
vel: torch.Tensor,
acc: torch.Tensor,
@@ -506,13 +539,15 @@ def calculate_tsteps(
max_vel: torch.Tensor,
max_acc: torch.Tensor,
max_jerk: torch.Tensor,
raw_dt: float,
raw_dt: torch.Tensor,
min_dt: float,
horizon: int,
optimize_dt: bool = True,
):
# compute scaled dt:
opt_dt = calculate_dt(vel, acc, jerk, max_vel, max_acc, max_jerk, raw_dt, interpolation_dt)
opt_dt = calculate_dt_fixed(
vel, acc, jerk, max_vel, max_acc, max_jerk, raw_dt, interpolation_dt
)
if not optimize_dt:
opt_dt[:] = raw_dt
traj_steps = (torch.ceil((horizon - 1) * ((opt_dt) / interpolation_dt))).to(dtype=torch.int32)
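Note: in calculate_dt_fixed above the velocity ratio rescales dt linearly, the acceleration ratio by its square root, and the jerk ratio by its cube root; the largest of the three, clamped to [interpolation_dt, raw_dt], becomes the retiming factor. A worked example with made-up limits:

# Illustrative arithmetic only; mirrors the dt_score computation above.
import torch

raw_dt = torch.tensor(0.5)
vel_ratio, acc_ratio, jerk_ratio = 0.5, 0.25, 0.125     # max observed value / joint limit
dt_vel = raw_dt * vel_ratio                              # 0.25
dt_acc = raw_dt * torch.tensor(acc_ratio).sqrt()         # 0.25
dt_jerk = raw_dt * torch.tensor(jerk_ratio) ** (1 / 3)   # 0.25
dt_score = torch.clamp(torch.max(torch.stack([dt_vel, dt_acc, dt_jerk])), 0.02, raw_dt)
print(dt_score)                                          # tensor(0.2500)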
@@ -11,6 +11,7 @@
# Standard Library
import os
import shutil
import sys
from typing import Dict, List

# Third Party

@@ -19,13 +19,16 @@ from torch.profiler import record_function
# CuRobo
from curobo.cuda_robot_model.cuda_robot_model import CudaRobotModel
from curobo.geom.cv import get_projection_rays, project_depth_using_rays
from curobo.geom.types import PointCloud
from curobo.types.base import TensorDeviceType
from curobo.types.camera import CameraObservation
from curobo.types.math import Pose
from curobo.types.robot import RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_error
from curobo.util.torch_utils import (
get_torch_compile_options,
get_torch_jit_decorator,
is_torch_compile_available,
)
from curobo.util_file import get_robot_configs_path, join_path, load_yaml
from curobo.wrap.model.robot_world import RobotWorld, RobotWorldConfig

@@ -36,6 +39,7 @@ class RobotSegmenter:
robot_world: RobotWorld,
distance_threshold: float = 0.05,
use_cuda_graph: bool = True,
ops_dtype: torch.dtype = torch.float32,
):
self._robot_world = robot_world
self._projection_rays = None
@@ -48,11 +52,12 @@ class RobotSegmenter:
self._use_cuda_graph = use_cuda_graph
self.tensor_args = robot_world.tensor_args
self.distance_threshold = distance_threshold
self._ops_dtype = ops_dtype

@staticmethod
def from_robot_file(
robot_file: Union[str, Dict],
collision_sphere_buffer: Optional[float],
collision_sphere_buffer: Optional[float] = None,
distance_threshold: float = 0.05,
use_cuda_graph: bool = True,
tensor_args: TensorDeviceType = TensorDeviceType(),
@@ -78,7 +83,7 @@ class RobotSegmenter:
def get_pointcloud_from_depth(self, camera_obs: CameraObservation):
if self._projection_rays is None:
self.update_camera_projection(camera_obs)
depth_image = camera_obs.depth_image
depth_image = camera_obs.depth_image.to(dtype=self._ops_dtype)
if len(depth_image.shape) == 2:
depth_image = depth_image.unsqueeze(0)
points = project_depth_using_rays(depth_image, self._projection_rays)
@@ -91,7 +96,7 @@ class RobotSegmenter:
intrinsics = intrinsics.unsqueeze(0)
project_rays = get_projection_rays(
camera_obs.depth_image.shape[-2], camera_obs.depth_image.shape[-1], intrinsics
)
).to(dtype=self._ops_dtype)

if self._projection_rays is None:
self._projection_rays = project_rays
@@ -157,8 +162,12 @@ class RobotSegmenter:
def _mask_op(self, camera_obs, q):
if len(q.shape) == 1:
q = q.unsqueeze(0)

robot_spheres = self._robot_world.get_kinematics(q).link_spheres_tensor

points = self.get_pointcloud_from_depth(camera_obs)
camera_to_robot = camera_obs.pose
points = points.to(dtype=torch.float32)

if self._out_points_buffer is None:
self._out_points_buffer = points.clone()
@@ -181,9 +190,9 @@ class RobotSegmenter:

out_points = points_in_robot_frame

dist = self._robot_world.get_point_robot_distance(out_points, q)

mask, filtered_image = mask_image(camera_obs.depth_image, dist, self.distance_threshold)
mask, filtered_image = mask_spheres_image(
camera_obs.depth_image, robot_spheres, out_points, self.distance_threshold
)

return mask, filtered_image

@@ -200,7 +209,7 @@ class RobotSegmenter:
return self.kinematics.base_link

@torch.jit.script
@get_torch_jit_decorator()
def mask_image(
image: torch.Tensor, distance: torch.Tensor, distance_threshold: float
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -212,3 +221,36 @@ def mask_image(
mask = torch.logical_and((image > 0.0), (distance > -distance_threshold))
filtered_image = torch.where(mask, 0, image)
return mask, filtered_image

@get_torch_jit_decorator()
def mask_spheres_image(
image: torch.Tensor,
link_spheres_tensor: torch.Tensor,
points: torch.Tensor,
distance_threshold: float,
) -> Tuple[torch.Tensor, torch.Tensor]:

if link_spheres_tensor.shape[0] != 1:
assert link_spheres_tensor.shape[0] == points.shape[0]
if len(points.shape) == 2:
points = points.unsqueeze(0)

robot_spheres = link_spheres_tensor.view(link_spheres_tensor.shape[0], -1, 4).contiguous()
robot_spheres = robot_spheres.unsqueeze(-3)

robot_radius = robot_spheres[..., 3]
points = points.unsqueeze(-2)
sph_distance = -1 * (
torch.linalg.norm(points - robot_spheres[..., :3], dim=-1) - robot_radius
) # b, n_spheres
distance = torch.max(sph_distance, dim=-1)[0]

distance = distance.view(
image.shape[0],
image.shape[1],
image.shape[2],
)
mask = torch.logical_and((image > 0.0), (distance > -distance_threshold))
filtered_image = torch.where(mask, 0, image)
return mask, filtered_image
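Note: a shape sketch for mask_spheres_image above (sizes assumed for illustration): the depth image is zeroed wherever its back-projected point lies within distance_threshold of any robot collision sphere.

# Illustrative shapes only; mask_spheres_image is the function defined above.
import torch

image = torch.rand(1, 480, 640)          # (batch, height, width) depth image
points = torch.rand(1, 480 * 640, 3)     # back-projected points in the robot base frame
link_spheres = torch.rand(1, 65, 4)      # (batch, n_spheres, [x, y, z, radius])
# mask, filtered = mask_spheres_image(image, link_spheres, points, distance_threshold=0.05)
# mask: (1, 480, 640) bool; filtered: depth image with robot pixels set to 0.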
@@ -36,6 +36,11 @@ from curobo.types.robot import RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_error
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.torch_utils import (
get_torch_compile_options,
get_torch_jit_decorator,
is_torch_compile_available,
)
from curobo.util.warp import init_warp
from curobo.util_file import get_robot_configs_path, get_world_configs_path, join_path, load_yaml

@@ -192,6 +197,9 @@ class RobotWorld(RobotWorldConfig):
def update_world(self, world_config: WorldConfig):
self.world_model.load_collision_model(world_config)

def clear_world_cache(self):
self.world_model.clear_cache()

def get_collision_distance(
self, x_sph: torch.Tensor, env_query_idx: Optional[torch.Tensor] = None
) -> torch.Tensor:
@@ -364,16 +372,7 @@ class RobotWorld(RobotWorldConfig):
if len(q.shape) == 1:
log_error("q should be of shape [b, dof]")
kin_state = self.get_kinematics(q)
b, n = None, None
if len(points.shape) == 3:
b, n, _ = points.shape
points = points.view(b * n, 3)

pt_distance = point_robot_distance(kin_state.link_spheres_tensor, points)

if b is not None:
pt_distance = pt_distance.view(b, n)

return pt_distance

def get_active_js(self, full_js: JointState):
@@ -382,27 +381,50 @@ class RobotWorld(RobotWorldConfig):
return out_js

@torch.jit.script
@get_torch_jit_decorator()
def sum_mask(d1, d2, d3):
d_total = d1 + d2 + d3
d_mask = d_total == 0.0
return d_mask.view(-1)

@torch.jit.script
@get_torch_jit_decorator()
def mask(d1, d2, d3):
d_total = d1 + d2 + d3
d_mask = d_total == 0.0
return d_mask

@torch.jit.script
@get_torch_jit_decorator()
def point_robot_distance(link_spheres_tensor, points):
robot_spheres = link_spheres_tensor.view(1, -1, 4).contiguous()
robot_radius = robot_spheres[:, :, 3]
points = points.unsqueeze(1)
sph_distance = (
torch.linalg.norm(points - robot_spheres[:, :, :3], dim=-1) - robot_radius
"""Compute distance between robot and points

Args:
link_spheres_tensor: [batch_robot, n_robot_spheres, 4]
points: [batch_points, n_points, 3]

Returns:
distance: [batch_points, n_points]
"""
if link_spheres_tensor.shape[0] != 1:
assert link_spheres_tensor.shape[0] == points.shape[0]
squeeze_shape = False
n = 1
if len(points.shape) == 2:
squeeze_shape = True
n, _ = points.shape
points = points.unsqueeze(0)

robot_spheres = link_spheres_tensor.view(link_spheres_tensor.shape[0], -1, 4).contiguous()
robot_spheres = robot_spheres.unsqueeze(-3)

robot_radius = robot_spheres[..., 3]
points = points.unsqueeze(-2)
sph_distance = -1 * (
torch.linalg.norm(points - robot_spheres[..., :3], dim=-1) - robot_radius
) # b, n_spheres
pt_distance = torch.max(-1 * sph_distance, dim=-1)[0]
pt_distance = torch.max(sph_distance, dim=-1)[0]

if squeeze_shape:
pt_distance = pt_distance.view(n)
return pt_distance
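Note: usage sketch for the rewritten point_robot_distance above (sizes assumed; after the sign flip in this hunk, positive values mean the point penetrates a robot sphere):

# Illustrative only; point_robot_distance is defined above.
import torch

link_spheres = torch.rand(1, 65, 4)      # (batch_robot, n_spheres, [x, y, z, radius])
points = torch.rand(2, 1000, 3)          # (batch_points, n_points, 3)
# dist = point_robot_distance(link_spheres, points)   # -> (2, 1000)
# A 2D (n_points, 3) input is also accepted and returns a (n_points,) tensor.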
@@ -20,6 +20,7 @@ import torch.autograd.profiler as profiler
# CuRobo
from curobo.types.robot import JointState
from curobo.types.tensor import T_DOF
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.trajectory import calculate_dt

@@ -32,7 +33,7 @@ class TrajEvaluatorConfig:
max_dt: float = 0.1

@torch.jit.script
@get_torch_jit_decorator()
def compute_path_length(vel, traj_dt, cspace_distance_weight):
pl = torch.sum(
torch.sum(torch.abs(vel) * traj_dt.unsqueeze(-1) * cspace_distance_weight, dim=-1), dim=-1
@@ -40,24 +41,25 @@ def compute_path_length(vel, traj_dt, cspace_distance_weight):
return pl

@torch.jit.script
@get_torch_jit_decorator()
def compute_path_length_cost(vel, cspace_distance_weight):
pl = torch.sum(torch.sum(torch.abs(vel) * cspace_distance_weight, dim=-1), dim=-1)
return pl

@torch.jit.script
@get_torch_jit_decorator()
def smooth_cost(abs_acc, abs_jerk, opt_dt):
# acc = torch.max(torch.max(abs_acc, dim=-1)[0], dim=-1)[0]
# jerk = torch.max(torch.max(abs_jerk, dim=-1)[0], dim=-1)[0]
jerk = torch.mean(torch.max(abs_jerk, dim=-1)[0], dim=-1)
mean_acc = torch.mean(torch.max(abs_acc, dim=-1)[0], dim=-1) # [0]
a = (jerk * 0.001) + 5.0 * opt_dt + (mean_acc * 0.01)
a = (jerk * 0.001) + 10.0 * opt_dt + (mean_acc * 0.01)
# a = (jerk * 0.001) + 50.0 * opt_dt + (mean_acc * 0.01)

return a

@torch.jit.script
@get_torch_jit_decorator()
def compute_smoothness(
vel: torch.Tensor,
acc: torch.Tensor,
@@ -104,7 +106,7 @@ def compute_smoothness(
return (acc_label, smooth_cost(abs_acc, abs_jerk, dt_score))

@torch.jit.script
@get_torch_jit_decorator()
def compute_smoothness_opt_dt(
vel, acc, jerk, max_vel: torch.Tensor, max_acc: float, max_jerk: float, opt_dt: torch.Tensor
):
@@ -34,6 +34,7 @@ from curobo.types.robot import JointState, RobotConfig
from curobo.types.tensor import T_BDOF, T_BValue_bool, T_BValue_float
from curobo.util.logger import log_error, log_warn
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util_file import (
get_robot_configs_path,
get_task_configs_path,
@@ -1010,7 +1011,7 @@ class IKSolver(IKSolverConfig):
]

@torch.jit.script
@get_torch_jit_decorator()
def get_success(
feasible,
position_error,
@@ -1028,7 +1029,7 @@ def get_success(
return success

@torch.jit.script
@get_torch_jit_decorator()
def get_result(
pose_error,
position_error,
@@ -197,7 +197,7 @@ class MotionGenConfig:
smooth_weight: List[float] = None,
finetune_smooth_weight: Optional[List[float]] = None,
state_finite_difference_mode: Optional[str] = None,
finetune_dt_scale: float = 0.98,
finetune_dt_scale: float = 0.95,
maximum_trajectory_time: Optional[float] = None,
maximum_trajectory_dt: float = 0.1,
velocity_scale: Optional[Union[List[float], float]] = None,
@@ -2086,6 +2086,7 @@ class MotionGen(MotionGenConfig):
# self.trajopt_solver.compute_metrics(not og_evaluate, og_evaluate)
if self.store_debug_in_result:
result.debug_info["trajopt_result"] = traj_result

# run finetune
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
with profiler.record_function("motion_gen/finetune_trajopt"):
@@ -2113,6 +2114,9 @@ class MotionGen(MotionGenConfig):
traj_result.solve_time = og_solve_time
if self.store_debug_in_result:
result.debug_info["finetune_trajopt_result"] = traj_result
elif plan_config.enable_finetune_trajopt:
traj_result.success = traj_result.success[:, 0]

result.success = traj_result.success

result.interpolated_plan = traj_result.interpolated_solution
@@ -2190,7 +2194,7 @@ class MotionGen(MotionGenConfig):
# warm up js_trajopt:
goal_state = start_state.clone()
goal_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(2):
for _ in range(3):
self.plan_single_js(start_state, goal_state, MotionGenPlanConfig(max_attempts=1))
if enable_graph:
start_state = JointState.from_position(
@@ -2214,7 +2218,7 @@ class MotionGen(MotionGenConfig):
if n_goalset == -1:
retract_pose = Pose(state.ee_pos_seq, quaternion=state.ee_quat_seq)
start_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(2):
for _ in range(3):
self.plan_single(
start_state,
retract_pose,
@@ -2243,7 +2247,7 @@ class MotionGen(MotionGenConfig):
quaternion=state.ee_quat_seq.repeat(n_goalset, 1).view(1, n_goalset, 4),
)
start_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(2):
for _ in range(3):
self.plan_goalset(
start_state,
retract_pose,
@@ -2278,7 +2282,7 @@ class MotionGen(MotionGenConfig):
retract_pose = Pose(state.ee_pos_seq, quaternion=state.ee_quat_seq)
start_state.position[..., warmup_joint_index] += warmup_joint_delta

for _ in range(2):
for _ in range(3):
if batch_env_mode:
self.plan_batch_env(
start_state,
@@ -2307,7 +2311,7 @@ class MotionGen(MotionGenConfig):
.contiguous(),
)
start_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(2):
for _ in range(3):
if batch_env_mode:
self.plan_batch_env_goalset(
start_state,
@@ -41,6 +41,7 @@ from curobo.types.robot import JointState, RobotConfig
from curobo.types.tensor import T_BDOF, T_DOF, T_BValue_bool, T_BValue_float
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.torch_utils import get_torch_jit_decorator, is_torch_compile_available
from curobo.util.trajectory import (
InterpolateType,
calculate_dt_no_clamp,
@@ -877,24 +878,37 @@ class TrajOptSolver(TrajOptSolverConfig):
result.metrics.goalset_index = metrics.goalset_index

st_time = time.time()
feasible = torch.all(result.metrics.feasible, dim=-1)
if result.metrics.cspace_error is None and result.metrics.position_error is None:
raise log_error("convergence check requires either goal_pose or goal_state")

if result.metrics.position_error is not None:
converge = torch.logical_and(
result.metrics.position_error[..., -1] <= self.position_threshold,
result.metrics.rotation_error[..., -1] <= self.rotation_threshold,
)
elif result.metrics.cspace_error is not None:
converge = result.metrics.cspace_error[..., -1] <= self.cspace_threshold
else:
raise ValueError("convergence check requires either goal_pose or goal_state")
success = jit_feasible_success(
result.metrics.feasible,
result.metrics.position_error,
result.metrics.rotation_error,
result.metrics.cspace_error,
self.position_threshold,
self.rotation_threshold,
self.cspace_threshold,
)
if False:
feasible = torch.all(result.metrics.feasible, dim=-1)

success = torch.logical_and(feasible, converge)
if result.metrics.position_error is not None:
converge = torch.logical_and(
result.metrics.position_error[..., -1] <= self.position_threshold,
result.metrics.rotation_error[..., -1] <= self.rotation_threshold,
)
elif result.metrics.cspace_error is not None:
converge = result.metrics.cspace_error[..., -1] <= self.cspace_threshold
else:
raise ValueError("convergence check requires either goal_pose or goal_state")

success = torch.logical_and(feasible, converge)
if return_all_solutions:
traj_result = TrajResult(
success=success,
goal=goal,
solution=result.action.scale(self.solver_dt / opt_dt.view(-1, 1, 1)),
solution=result.action.scale_by_dt(self.solver_dt_tensor, opt_dt.view(-1, 1, 1)),
seed=seed_traj,
position_error=result.metrics.position_error,
rotation_error=result.metrics.rotation_error,
@@ -928,49 +942,86 @@ class TrajOptSolver(TrajOptSolverConfig):
            )

        with profiler.record_function("trajopt/best_select"):
            success[~smooth_label] = False
            # get the best solution:
            if result.metrics.pose_error is not None:
                convergence_error = result.metrics.pose_error[..., -1]
            elif result.metrics.cspace_error is not None:
                convergence_error = result.metrics.cspace_error[..., -1]
            if True: # not get_torch_jit_decorator() == torch.jit.script:
                # This only works if torch compile is available:
                (
                    idx,
                    position_error,
                    rotation_error,
                    cspace_error,
                    goalset_index,
                    opt_dt,
                    success,
                ) = jit_trajopt_best_select(
                    success,
                    smooth_label,
                    result.metrics.cspace_error,
                    result.metrics.pose_error,
                    result.metrics.position_error,
                    result.metrics.rotation_error,
                    result.metrics.goalset_index,
                    result.metrics.cost,
                    smooth_cost,
                    batch_mode,
                    goal.batch,
                    num_seeds,
                    self._col,
                    opt_dt,
                )
                if batch_mode:
                    last_tstep = [last_tstep[i] for i in idx]
                else:
                    last_tstep = [last_tstep[idx.item()]]
                best_act_seq = result.action[idx]
                best_raw_action = result.raw_action[idx]
                interpolated_traj = interpolated_trajs[idx]

            else:
                raise ValueError("convergence check requires either goal_pose or goal_state")
            running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
            error = convergence_error + smooth_cost + running_cost
            error[~success] += 10000.0
            if batch_mode:
                idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
                idx = idx + num_seeds * self._col
                last_tstep = [last_tstep[i] for i in idx]
                success = success[idx]
            else:
                idx = torch.argmin(error, dim=0)
                success[~smooth_label] = False
                # get the best solution:
                if result.metrics.pose_error is not None:
                    convergence_error = result.metrics.pose_error[..., -1]
                elif result.metrics.cspace_error is not None:
                    convergence_error = result.metrics.cspace_error[..., -1]
                else:
                    raise ValueError(
                        "convergence check requires either goal_pose or goal_state"
                    )
                running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
                error = convergence_error + smooth_cost + running_cost
                error[~success] += 10000.0
                if batch_mode:
                    idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
                    idx = idx + num_seeds * self._col
                    last_tstep = [last_tstep[i] for i in idx]
                    success = success[idx]
                else:
                    idx = torch.argmin(error, dim=0)

                    last_tstep = [last_tstep[idx.item()]]
                    success = success[idx : idx + 1]
                last_tstep = [last_tstep[idx.item()]]
                success = success[idx : idx + 1]

                best_act_seq = result.action[idx]
                best_raw_action = result.raw_action[idx]
                interpolated_traj = interpolated_trajs[idx]
                goalset_index = position_error = rotation_error = cspace_error = None
                if result.metrics.position_error is not None:
                    position_error = result.metrics.position_error[idx, -1]
                if result.metrics.rotation_error is not None:
                    rotation_error = result.metrics.rotation_error[idx, -1]
                if result.metrics.cspace_error is not None:
                    cspace_error = result.metrics.cspace_error[idx, -1]
                if result.metrics.goalset_index is not None:
                    goalset_index = result.metrics.goalset_index[idx, -1]
            best_act_seq = result.action[idx]
            best_raw_action = result.raw_action[idx]
            interpolated_traj = interpolated_trajs[idx]
            goalset_index = position_error = rotation_error = cspace_error = None
            if result.metrics.position_error is not None:
                position_error = result.metrics.position_error[idx, -1]
            if result.metrics.rotation_error is not None:
                rotation_error = result.metrics.rotation_error[idx, -1]
            if result.metrics.cspace_error is not None:
                cspace_error = result.metrics.cspace_error[idx, -1]
            if result.metrics.goalset_index is not None:
                goalset_index = result.metrics.goalset_index[idx, -1]

                opt_dt = opt_dt[idx]
            opt_dt = opt_dt[idx]
            if self.sync_cuda_time:
                torch.cuda.synchronize()
            if len(best_act_seq.shape) == 3:
                opt_dt_v = opt_dt.view(-1, 1, 1)
            else:
                opt_dt_v = opt_dt.view(1, 1)
            opt_solution = best_act_seq.scale(self.solver_dt / opt_dt_v)
            opt_solution = best_act_seq.scale_by_dt(self.solver_dt_tensor, opt_dt_v)
            select_time = time.time() - st_time
            debug_info = None
            if self.store_debug_in_result:
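To make the selection rule in this hunk concrete: each seed's score is its last-timestep convergence error plus a smoothness cost plus a heavily down-weighted running cost, and any seed that failed the feasibility or convergence check is pushed out of contention with a large constant before the argmin. A toy sketch with made-up numbers (names and values are illustrative):

    import torch

    num_seeds = 3
    pose_error = torch.tensor([0.002, 0.030, 0.001])       # last-timestep convergence error per seed
    smooth_cost = torch.tensor([0.10, 0.05, 0.20])         # smoothness metric per seed
    running_cost = torch.tensor([5.0, 4.0, 6.0]) * 0.0001  # heavily down-weighted running cost
    success = torch.tensor([True, True, False])            # seed 2 is infeasible despite low error

    error = pose_error + smooth_cost + running_cost
    error[~success] += 10000.0  # infeasible seeds can never win the argmin
    best = torch.argmin(error, dim=0)
    print(best.item())          # -> 1: lowest combined error among the successful seeds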
@@ -1174,7 +1225,7 @@ class TrajOptSolver(TrajOptSolverConfig):
            self._max_joint_vel,
            self._max_joint_acc,
            self._max_joint_jerk,
            self.solver_dt,
            self.solver_dt_tensor,
            kind=self.interpolation_type,
            tensor_args=self.tensor_args,
            out_traj_state=self._interpolated_traj_buffer,
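The call above now receives `self.solver_dt_tensor` instead of `self.solver_dt` when building the interpolated trajectory. The interpolation routine itself is not shown in this diff; as a rough illustration of what retiming to a fixed output time step involves, here is a minimal linear-resampling sketch. The function `resample_linear` and its end-of-trajectory clamping are assumptions, and the joint velocity/acceleration/jerk limits passed above (`self._max_joint_vel` and friends) are not modeled here:

    import torch

    def resample_linear(positions: torch.Tensor, traj_dt: torch.Tensor, out_dt: float, out_steps: int):
        # positions: [horizon, dof] sampled every traj_dt seconds (traj_dt kept as a 0-d tensor).
        # Returns [out_steps, dof] sampled every out_dt seconds via linear interpolation,
        # clamping queries past the end of the trajectory to the last waypoint.
        horizon = positions.shape[0]
        t_query = torch.arange(out_steps, device=positions.device) * out_dt / traj_dt
        idx0 = t_query.floor().long().clamp(0, horizon - 1)
        idx1 = (idx0 + 1).clamp(0, horizon - 1)
        frac = (t_query - idx0.float()).clamp(0.0, 1.0).unsqueeze(-1)
        return positions[idx0] * (1.0 - frac) + positions[idx1] * frac

    traj = torch.tensor([[0.0], [1.0], [2.0], [3.0]])  # 4 waypoints, 1 dof
    out = resample_linear(traj, traj_dt=torch.tensor(0.1), out_dt=0.05, out_steps=7)
    print(out.squeeze(-1))  # [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]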
@@ -1224,7 +1275,12 @@ class TrajOptSolver(TrajOptSolverConfig):

    @property
    def solver_dt(self):
        return self.solver.safety_rollout.dynamics_model.dt_traj_params.base_dt
        return self.solver.safety_rollout.dynamics_model.traj_dt[0]
        # return self.solver.safety_rollout.dynamics_model.dt_traj_params.base_dt

    @property
    def solver_dt_tensor(self):
        return self.solver.safety_rollout.dynamics_model.traj_dt[0]

    def update_solver_dt(
        self,
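`solver_dt` now returns `traj_dt[0]` (a tensor) and a separate `solver_dt_tensor` property is added. The diff does not state the motivation; a common reason to keep dt as a device tensor is that converting it to a Python float forces a GPU-to-host synchronization, while using the tensor directly in downstream math does not. A small sketch of that difference (variable names are illustrative):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    traj_dt = torch.full((30,), 0.02, device=device)

    dt_tensor = traj_dt[0]                  # stays on device; usable directly in tensor expressions
    velocity = torch.randn(30, 7, device=device)
    position_delta = velocity * dt_tensor   # no host synchronization needed here

    dt_float = float(traj_dt[0])            # float()/.item() synchronizes when the tensor is on the GPU
    position_delta_2 = velocity * dt_float  # numerically the same result
    assert torch.allclose(position_delta, position_delta_2)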
@@ -1254,3 +1310,79 @@ class TrajOptSolver(TrajOptSolverConfig):
            for rollout in rollouts
            if isinstance(rollout, ArmReacher)
        ]


@get_torch_jit_decorator()
def jit_feasible_success(
    feasible,
    position_error: Union[torch.Tensor, None],
    rotation_error: Union[torch.Tensor, None],
    cspace_error: Union[torch.Tensor, None],
    position_threshold: float,
    rotation_threshold: float,
    cspace_threshold: float,
):
    feasible = torch.all(feasible, dim=-1)
    converge = feasible
    if position_error is not None and rotation_error is not None:
        converge = torch.logical_and(
            position_error[..., -1] <= position_threshold,
            rotation_error[..., -1] <= rotation_threshold,
        )
    elif cspace_error is not None:
        converge = cspace_error[..., -1] <= cspace_threshold

    success = torch.logical_and(feasible, converge)
    return success
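A toy check that mirrors `jit_feasible_success` above: a seed counts as successful only when every timestep is feasible and the last-timestep errors are within the thresholds (the threshold values and tensors below are made up):

    import torch

    feasible = torch.tensor([[True, True, True],
                             [True, False, True]])              # [num_seeds, horizon]
    position_error = torch.tensor([[0.02, 0.01, 0.004],
                                   [0.02, 0.01, 0.004]])
    rotation_error = torch.tensor([[0.10, 0.06, 0.03],
                                   [0.10, 0.06, 0.03]])

    all_feasible = torch.all(feasible, dim=-1)                              # [True, False]
    converge = torch.logical_and(position_error[..., -1] <= 0.005,
                                 rotation_error[..., -1] <= 0.05)           # [True, True]
    success = torch.logical_and(all_feasible, converge)
    print(success)  # tensor([ True, False])
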
@get_torch_jit_decorator(only_valid_for_compile=True)
def jit_trajopt_best_select(
    success,
    smooth_label,
    cspace_error: Union[torch.Tensor, None],
    pose_error: Union[torch.Tensor, None],
    position_error: Union[torch.Tensor, None],
    rotation_error: Union[torch.Tensor, None],
    goalset_index: Union[torch.Tensor, None],
    cost,
    smooth_cost,
    batch_mode: bool,
    batch: int,
    num_seeds: int,
    col,
    opt_dt,
):
    success[~smooth_label] = False
    convergence_error = 0
    # get the best solution:
    if pose_error is not None:
        convergence_error = pose_error[..., -1]
    elif cspace_error is not None:
        convergence_error = cspace_error[..., -1]

    running_cost = torch.mean(cost, dim=-1) * 0.0001
    error = convergence_error + smooth_cost + running_cost
    error[~success] += 10000.0
    if batch_mode:
        idx = torch.argmin(error.view(batch, num_seeds), dim=-1)
        idx = idx + num_seeds * col
        success = success[idx]
    else:
        idx = torch.argmin(error, dim=0)

        success = success[idx : idx + 1]

    # goalset_index = position_error = rotation_error = cspace_error = None
    if position_error is not None:
        position_error = position_error[idx, -1]
    if rotation_error is not None:
        rotation_error = rotation_error[idx, -1]
    if cspace_error is not None:
        cspace_error = cspace_error[idx, -1]
    if goalset_index is not None:
        goalset_index = goalset_index[idx, -1]

    opt_dt = opt_dt[idx]

    return idx, position_error, rotation_error, cspace_error, goalset_index, opt_dt, success
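
In the batched branch of `jit_trajopt_best_select`, the flat per-seed error vector is viewed as `[batch, num_seeds]`, the best seed is found per problem, and `num_seeds * col` maps each winner back to a flat index. The diff does not define `col`; the sketch below assumes it is an `arange(batch)`-style offset:

    import torch

    batch, num_seeds = 2, 3
    error = torch.tensor([0.9, 0.2, 0.5,    # problem 0: seed 1 is best
                          0.4, 0.8, 0.1])   # problem 1: seed 2 is best

    per_problem = torch.argmin(error.view(batch, num_seeds), dim=-1)  # tensor([1, 2])
    col = torch.arange(batch)                                         # assumed role of `col`
    flat_idx = per_problem + num_seeds * col                          # tensor([1, 5])
    print(error[flat_idx])                                            # tensor([0.2000, 0.1000])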