Significantly improved convergence for mesh and cuboid collision checking; added a new ESDF collision checker.

Balakumar Sundaralingam
2024-03-18 11:19:48 -07:00
parent 286b3820a5
commit b1f63e8778
100 changed files with 7587 additions and 2589 deletions

View File

@@ -40,7 +40,7 @@ from curobo.types.base import TensorDeviceType
from curobo.types.robot import CSpaceConfig, RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.tensor_util import cat_sum
from curobo.util.tensor_util import cat_sum, cat_sum_horizon
@dataclass
@@ -104,10 +104,10 @@ class ArmCostConfig:
@dataclass
class ArmBaseConfig(RolloutConfig):
model_cfg: KinematicModelConfig
cost_cfg: ArmCostConfig
constraint_cfg: ArmCostConfig
convergence_cfg: ArmCostConfig
model_cfg: Optional[KinematicModelConfig] = None
cost_cfg: Optional[ArmCostConfig] = None
constraint_cfg: Optional[ArmCostConfig] = None
convergence_cfg: Optional[ArmCostConfig] = None
world_coll_checker: Optional[WorldCollision] = None
@staticmethod
@@ -322,7 +322,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
self.null_convergence = DistCost(self.convergence_cfg.null_space_cfg)
# set start state:
start_state = torch.randn((1, self.dynamics_model.d_state), **vars(self.tensor_args))
start_state = torch.randn(
(1, self.dynamics_model.d_state), **(self.tensor_args.as_torch_dict())
)
self._start_state = JointState(
position=start_state[:, : self.dynamics_model.d_dof],
velocity=start_state[:, : self.dynamics_model.d_dof],
@@ -366,9 +368,11 @@ class ArmBase(RolloutBase, ArmBaseConfig):
)
cost_list.append(coll_cost)
if return_list:
return cost_list
cost = cat_sum(cost_list)
if self.sum_horizon:
cost = cat_sum_horizon(cost_list)
else:
cost = cat_sum(cost_list)
return cost
def constraint_fn(
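
This hunk switches the cost reduction between cat_sum and cat_sum_horizon (both imported from curobo.util.tensor_util). Their implementations are not shown in this diff; the sketch below infers their behavior from the cat_sum_reacher / cat_sum_horizon_reacher helpers added later in this commit, so treat it as illustrative rather than the library's exact code.

# Hedged sketch: behavior inferred from cat_sum_reacher / cat_sum_horizon_reacher
# added later in this commit; the actual curobo.util.tensor_util code may differ.
from typing import List

import torch


def cat_sum_sketch(tensor_list: List[torch.Tensor]) -> torch.Tensor:
    # Sum a list of per-term cost tensors, keeping the (batch, horizon) shape.
    return torch.sum(torch.stack(tensor_list, dim=0), dim=0)


def cat_sum_horizon_sketch(tensor_list: List[torch.Tensor]) -> torch.Tensor:
    # Also reduce over the trailing horizon dimension, one scalar cost per batch entry.
    return torch.sum(torch.stack(tensor_list, dim=0), dim=(0, -1))


costs = [torch.rand(4, 30), torch.rand(4, 30)]   # e.g. (batch=4, horizon=30) cost terms
print(cat_sum_sketch(costs).shape)               # torch.Size([4, 30])
print(cat_sum_horizon_sketch(costs).shape)       # torch.Size([4])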

View File

@@ -10,7 +10,7 @@
#
# Standard Library
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict, List, Optional
# Third Party
import torch
@@ -29,8 +29,9 @@ from curobo.types.base import TensorDeviceType
from curobo.types.robot import RobotConfig
from curobo.types.tensor import T_BValue_float, T_BValue_int
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_error, log_info
from curobo.util.tensor_util import cat_max, cat_sum
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.tensor_util import cat_max
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .arm_base import ArmBase, ArmBaseConfig, ArmCostConfig
@@ -145,7 +146,7 @@ class ArmReacherConfig(ArmBaseConfig):
)
@torch.jit.script
@get_torch_jit_decorator()
def _compute_g_dist_jit(rot_err_norm, goal_dist):
# goal_cost = goal_cost.view(cost.shape)
# rot_err_norm = rot_err_norm.view(cost.shape)
@@ -319,7 +320,12 @@ class ArmReacher(ArmBase, ArmReacherConfig):
g_dist,
)
cost_list.append(z_vel)
cost = cat_sum(cost_list)
with profiler.record_function("cat_sum"):
if self.sum_horizon:
cost = cat_sum_horizon_reacher(cost_list)
else:
cost = cat_sum_reacher(cost_list)
return cost
def convergence_fn(
@@ -466,3 +472,15 @@ class ArmReacher(ArmBase, ArmReacherConfig):
)
for x in pose_costs
]
@get_torch_jit_decorator()
def cat_sum_reacher(tensor_list: List[torch.Tensor]):
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=0)
return cat_tensor
@get_torch_jit_decorator()
def cat_sum_horizon_reacher(tensor_list: List[torch.Tensor]):
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=(0, -1))
return cat_tensor
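
The reduction is now wrapped in profiler.record_function("cat_sum") so it shows up as a named region in torch profiler traces. The import used by the file is not visible in this hunk; a minimal, standalone version of the same pattern using torch.autograd.profiler would look like this (names here are illustrative only):

import torch
import torch.autograd.profiler as profiler


def summed_cost(costs):
    # Label the reduction so it appears as "cat_sum" in profiler output.
    with profiler.record_function("cat_sum"):
        return torch.sum(torch.stack(costs, dim=0), dim=0)


with profiler.profile() as prof:
    out = summed_cost([torch.rand(2, 10), torch.rand(2, 10)])
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))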

View File

@@ -21,6 +21,7 @@ import warp as wp
from curobo.cuda_robot_model.types import JointLimits
from curobo.types.robot import JointState
from curobo.types.tensor import T_DOF
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.warp import init_warp
# Local Folder
@@ -267,7 +268,7 @@ class BoundCost(CostBase, BoundCostConfig):
return super().update_dt(dt)
@torch.jit.script
@get_torch_jit_decorator()
def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
# c = weight * torch.sum(torch.nn.functional.relu(torch.max(lower_bounds - p, p - upper_bounds)), dim=-1)
@@ -281,7 +282,7 @@ def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
return c
@torch.jit.script
@get_torch_jit_decorator()
def forward_all_bound_cost(
p,
v,
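
This file, like most of the cost modules in this commit, replaces the hard-coded @torch.jit.script with @get_torch_jit_decorator() from curobo.util.torch_utils. The decorator factory's implementation is not part of this diff; a plausible sketch of the pattern (with an invented environment switch, and a force_jit flag as used later in tensor_step.py) is:

import os

import torch


def get_torch_jit_decorator_sketch(force_jit: bool = False):
    # Hedged sketch: return torch.jit.script when scripting is wanted, otherwise a
    # no-op decorator so the same functions can run eagerly (e.g. under torch.compile).
    # The "USE_TORCH_JIT" variable name and default are invented for illustration.
    use_jit = force_jit or os.environ.get("USE_TORCH_JIT", "1") == "1"
    if use_jit:
        return torch.jit.script

    def passthrough(fn):
        return fn

    return passthrough


@get_torch_jit_decorator_sketch()
def scaled_norm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.norm(w * x, p=2, dim=-1, keepdim=False)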

View File

@@ -18,6 +18,7 @@ import torch
import warp as wp
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.warp import init_warp
# Local Folder
@@ -41,32 +42,32 @@ class DistCostConfig(CostConfig):
return super().__post_init__()
@torch.jit.script
@get_torch_jit_decorator()
def L2_DistCost_jit(vec_weight, disp_vec):
return torch.norm(vec_weight * disp_vec, p=2, dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def fwd_SQL2_DistCost_jit(vec_weight, disp_vec):
return torch.sum(torch.square(vec_weight * disp_vec), dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def fwd_L1_DistCost_jit(vec_weight, disp_vec):
return torch.sum(torch.abs(vec_weight * disp_vec), dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def L2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
return torch.norm(weight * vec_weight * (g_vec - c_vec), p=2, dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def fwd_SQL2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
return torch.sum(torch.square(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
@torch.jit.script
@get_torch_jit_decorator()
def fwd_L1_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
return torch.sum(torch.abs(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
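
All of these kernels reduce only the last (DOF) dimension, so batch and horizon are preserved. A quick eager shape check of the target L2 variant exactly as written above:

import torch

batch, horizon, dof = 4, 30, 7
vec_weight = torch.ones(dof)
g_vec = torch.zeros(batch, horizon, dof)    # e.g. goal configuration
c_vec = torch.rand(batch, horizon, dof)     # e.g. current configuration
weight = torch.tensor(5.0)

# Same expression as L2_DistCost_target_jit above, evaluated eagerly.
cost = torch.norm(weight * vec_weight * (g_vec - c_vec), p=2, dim=-1, keepdim=False)
print(cost.shape)  # torch.Size([4, 30])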

View File

@@ -19,6 +19,8 @@ import torch
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollision
from curobo.rollout.cost.cost_base import CostBase, CostConfig
from curobo.rollout.dynamics_model.integration_utils import interpolate_kernel, sum_matrix
from curobo.util.logger import log_info
from curobo.util.torch_utils import get_torch_jit_decorator
@dataclass
@@ -48,7 +50,11 @@ class PrimitiveCollisionCostConfig(CostConfig):
#: post optimization interpolation to not hit any obstacles.
activation_distance: Union[torch.Tensor, float] = 0.0
#: Setting this flag to True will sum the distances across colliding obstacles.
sum_collisions: bool = True
#: Setting this flag to True will sum the distances across the robot's spheres.
#: Setting it to False will only take the maximum distance.
sum_distance: bool = True
def __post_init__(self):
@@ -103,6 +109,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
self._collision_query_buffer.update_buffer_shape(
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
)
if not self.sum_distance:
log_info("sum_distance=False will be slower than sum_distance=True")
self.return_loss = True
dist = self.sweep_check_fn(
robot_spheres_in,
self._collision_query_buffer,
@@ -115,9 +124,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
return_loss=self.return_loss,
)
if self.classify:
cost = weight_collision(dist, self.weight, self.sum_distance)
cost = weight_collision(dist, self.sum_distance)
else:
cost = weight_distance(dist, self.weight, self.sum_distance)
cost = weight_distance(dist, self.sum_distance)
return cost
def sweep_fn(self, robot_spheres_in, env_query_idx: Optional[torch.Tensor] = None):
@@ -140,6 +149,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
self._collision_query_buffer.update_buffer_shape(
sampled_spheres.shape, self.tensor_args, self.world_coll_checker.collision_types
)
if not self.sum_distance:
log_info("sum_distance=False will be slower than sum_distance=True")
self.return_loss = True
dist = self.coll_check_fn(
sampled_spheres.contiguous(),
self._collision_query_buffer,
@@ -151,9 +163,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
dist = dist.view(batch_size, new_horizon, n_spheres)
if self.classify:
cost = weight_sweep_collision(self.int_sum_mat, dist, self.weight, self.sum_distance)
cost = weight_sweep_collision(self.int_sum_mat, dist, self.sum_distance)
else:
cost = weight_sweep_distance(self.int_sum_mat, dist, self.weight, self.sum_distance)
cost = weight_sweep_distance(self.int_sum_mat, dist, self.sum_distance)
return cost
@@ -161,6 +173,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
self._collision_query_buffer.update_buffer_shape(
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
)
if not self.sum_distance:
log_info("sum_distance=False will be slower than sum_distance=True")
self.return_loss = True
dist = self.coll_check_fn(
robot_spheres_in,
self._collision_query_buffer,
@@ -168,12 +183,13 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
env_query_idx=env_query_idx,
activation_distance=self.activation_distance,
return_loss=self.return_loss,
sum_collisions=self.sum_collisions,
)
if self.classify:
cost = weight_collision(dist, self.weight, self.sum_distance)
cost = weight_collision(dist, self.sum_distance)
else:
cost = weight_distance(dist, self.weight, self.sum_distance)
cost = weight_distance(dist, self.sum_distance)
return cost
def update_dt(self, dt: Union[float, torch.Tensor]):
@@ -184,31 +200,43 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
return self._collision_query_buffer.get_gradient_buffer()
@torch.jit.script
def weight_sweep_distance(int_mat, dist, weight, sum_cost: bool):
dist = torch.sum(dist, dim=-1)
@get_torch_jit_decorator()
def weight_sweep_distance(int_mat, dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
dist = dist @ int_mat
return dist
@torch.jit.script
def weight_sweep_collision(int_mat, dist, weight, sum_cost: bool):
dist = torch.sum(dist, dim=-1)
@get_torch_jit_decorator()
def weight_sweep_collision(int_mat, dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
dist = torch.where(dist > 0, dist + 1.0, dist)
dist = dist @ int_mat
return dist
@torch.jit.script
def weight_distance(dist, weight, sum_cost: bool):
@get_torch_jit_decorator()
def weight_distance(dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
return dist
@torch.jit.script
def weight_collision(dist, weight, sum_cost: bool):
@get_torch_jit_decorator()
def weight_collision(dist, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
else:
dist = torch.max(dist, dim=-1)[0]
dist = torch.where(dist > 0, dist + 1.0, dist)
return dist
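
A small eager check of the sum_distance switch these helpers now branch on: summing accumulates the cost over all robot spheres, while the max path keeps only the worst sphere; the classify variants additionally shift any positive entry by +1.0 (as in weight_collision above) so in-collision states are clearly separated from free ones.

import torch

# dist: (batch, horizon, n_spheres) collision cost per robot sphere
dist = torch.tensor([[[0.00, 0.02, 0.10]]])

sum_cost = torch.sum(dist, dim=-1)       # sum_distance=True path
max_cost = torch.max(dist, dim=-1)[0]    # sum_distance=False path
print(sum_cost)   # tensor([[0.1200]])
print(max_cost)   # tensor([[0.1000]])

# classify path: shift positive costs by +1 (mirrors weight_collision above)
coll = torch.where(max_cost > 0, max_cost + 1.0, max_cost)
print(coll)       # tensor([[1.1000]])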

View File

@@ -17,6 +17,7 @@ import torch
# CuRobo
from curobo.rollout.dynamics_model.kinematic_model import TimeTrajConfig
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .cost_base import CostBase, CostConfig
@@ -72,7 +73,7 @@ class StopCost(CostBase, StopCostConfig):
return cost
@torch.jit.script
@get_torch_jit_decorator()
def velocity_cost(vels, weight, max_vel):
vel_abs = torch.abs(vels)
vel_abs = torch.nn.functional.relu(vel_abs - max_vel[: vels.shape[1]])
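
The hinge above only penalizes the portion of the joint speed that exceeds the limit. A toy eager check (assuming max_vel is laid out per time step, as the max_vel[: vels.shape[1]] slice suggests):

import torch

horizon, dof = 4, 3
vels = torch.tensor([[[0.5, -1.5, 0.2]] * horizon])   # (1, horizon, dof)
max_vel = torch.ones(horizon, dof)                     # per-step velocity limit

overshoot = torch.nn.functional.relu(torch.abs(vels) - max_vel[: vels.shape[1]])
print(overshoot[0, 0])  # tensor([0.0000, 0.5000, 0.0000])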

View File

@@ -13,11 +13,14 @@
# Third Party
import torch
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .cost_base import CostBase, CostConfig
@torch.jit.script
@get_torch_jit_decorator()
def st_cost(ee_pos_batch, vec_weight, weight):
ee_plus_one = torch.roll(ee_pos_batch, 1, dims=1)
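
torch.roll with dims=1 pairs each end-effector waypoint with its predecessor along the horizon, which is how the straight-line cost obtains per-step displacements. A tiny shape check (note the first step wraps around to the last waypoint, which the full cost would need to mask or discount):

import torch

ee_pos_batch = torch.arange(12.0).view(1, 4, 3)     # (batch, horizon, xyz)
ee_plus_one = torch.roll(ee_pos_batch, 1, dims=1)   # previous waypoint at each step

delta = ee_pos_batch - ee_plus_one                   # per-step displacement
print(delta.shape)   # torch.Size([1, 4, 3])
print(delta[0, 1])   # tensor([3., 3., 3.]) -> uniform spacing in this toy trajectory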

View File

@@ -11,11 +11,14 @@
# Third Party
import torch
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .cost_base import CostBase
@torch.jit.script
@get_torch_jit_decorator()
def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
# return weight * torch.square(torch.linalg.norm(cost, dim=-1, ord=1))
# return weight * torch.sum(torch.square(cost), dim=-1)
@@ -24,7 +27,7 @@ def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
return torch.sum(torch.square(cost) * weight, dim=-1)
@torch.jit.script
@get_torch_jit_decorator()
def run_squared_sum(
cost: torch.Tensor, weight: torch.Tensor, run_weight: torch.Tensor
) -> torch.Tensor:
@@ -35,13 +38,13 @@ def run_squared_sum(
# return torch.sum(torch.square(cost), dim=-1) * weight * run_weight
@torch.jit.script
@get_torch_jit_decorator()
def backward_squared_sum(cost_vec, w):
return 2.0 * w * cost_vec # * g_out.unsqueeze(-1)
# return w * g_out.unsqueeze(-1)
@torch.jit.script
@get_torch_jit_decorator()
def backward_run_squared_sum(cost_vec, w, r_w):
return 2.0 * w * r_w.unsqueeze(-1) * cost_vec # * g_out.unsqueeze(-1)
# return w * r_w.unsqueeze(-1) * cost_vec * g_out.unsqueeze(-1)
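
squared_sum and backward_squared_sum form an analytic forward/backward pair: the gradient of sum(weight * cost**2) with respect to cost is 2 * weight * cost. A quick eager check that the hand-written gradient matches autograd (the autograd.Function wiring that ties the pair together is outside this hunk):

import torch

cost_vec = torch.rand(2, 5, 7, requires_grad=True)
weight = torch.rand(7)

# Forward, as in squared_sum above.
out = torch.sum(torch.square(cost_vec) * weight, dim=-1)
out.sum().backward()

# Analytic gradient, as in backward_squared_sum above.
manual = 2.0 * weight * cost_vec.detach()
print(torch.allclose(cost_vec.grad, manual))  # True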

View File

@@ -25,6 +25,7 @@ from curobo.curobolib.tensor_step import (
)
from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
from curobo.util.torch_utils import get_torch_jit_decorator
def build_clique_matrix(horizon, dt, device="cpu", dtype=torch.float32):
@@ -154,7 +155,7 @@ def build_start_state_mask(horizon, tensor_args: TensorDeviceType):
return mask, n_mask
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_jerk(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
@@ -176,7 +177,7 @@ def tensor_step_jerk(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_m
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def euler_integrate(q_0, u, diag_dt, integrate_matrix):
# q_new = q_0 + torch.matmul(integrate_matrix, torch.matmul(diag_dt, u))
q_new = q_0 + torch.matmul(integrate_matrix, u)
@@ -184,7 +185,7 @@ def euler_integrate(q_0, u, diag_dt, integrate_matrix):
return q_new
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_acc(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
# This is batch,n_dof
@@ -207,7 +208,7 @@ def tensor_step_acc(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_ma
return state_seq
@torch.jit.script
@get_torch_jit_decorator()
def jit_tensor_step_pos_clique_contiguous(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
state_position = (start_position.unsqueeze(1).transpose(1, 2) @ mask.transpose(0, 1)) + (
pos_act.transpose(1, 2) @ n_mask.transpose(0, 1)
@@ -222,7 +223,7 @@ def jit_tensor_step_pos_clique_contiguous(pos_act, start_position, mask, n_mask,
return state_position, state_vel, state_acc, state_jerk
@torch.jit.script
@get_torch_jit_decorator()
def jit_tensor_step_pos_clique(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
state_position = mask @ start_position.unsqueeze(1) + n_mask @ pos_act
state_vel = fd_1 @ state_position
@@ -231,7 +232,7 @@ def jit_tensor_step_pos_clique(pos_act, start_position, mask, n_mask, fd_1, fd_2
return state_position, state_vel, state_acc, state_jerk
@torch.jit.script
@get_torch_jit_decorator()
def jit_backward_pos_clique(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
p_grad = (
grad_p
@@ -247,7 +248,7 @@ def jit_backward_pos_clique(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2,
return u_grad
@torch.jit.script
@get_torch_jit_decorator()
def jit_backward_pos_clique_contiguous(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
p_grad = grad_p + (
grad_j.transpose(-1, -2) @ fd_3
@@ -532,7 +533,7 @@ class CliqueTensorStepIdxCentralDifferenceKernel(torch.autograd.Function):
start_position,
start_velocity,
start_acceleration,
start_idx,
start_idx.contiguous(),
traj_dt,
out_position.shape[0],
out_position.shape[1],
@@ -750,7 +751,7 @@ class AccelerationTensorStepIdxKernel(torch.autograd.Function):
return u_grad, None, None, None, None, None, None, None, None, None, None
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_pos_clique(
state: JointState,
act: torch.Tensor,
@@ -786,7 +787,7 @@ def step_acc_semi_euler(state, act, diag_dt, n_dofs, integrate_matrix):
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_acc_semi_euler(
state, act, state_seq, diag_dt, integrate_matrix, integrate_matrix_pos
):
@@ -806,7 +807,7 @@ def tensor_step_acc_semi_euler(
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_vel(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Tensor) -> Tensor
@@ -830,7 +831,7 @@ def tensor_step_vel(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_ma
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_pos(state, act, state_seq, fd_matrix):
# This is batch,n_dof
state_seq.position[:, 0, :] = state.position
@@ -850,7 +851,7 @@ def tensor_step_pos(state, act, state_seq, fd_matrix):
return state_seq
# @torch.jit.script
# @get_torch_jit_decorator()
def tensor_step_pos_ik(act, state_seq):
state_seq.position = act
return state_seq
@@ -869,7 +870,7 @@ def tensor_linspace(start_tensor, end_tensor, steps=10):
def sum_matrix(h, int_steps, tensor_args):
sum_mat = torch.zeros(((h - 1) * int_steps, h), **vars(tensor_args))
sum_mat = torch.zeros(((h - 1) * int_steps, h), **(tensor_args.as_torch_dict()))
for i in range(h - 1):
sum_mat[i * int_steps : i * int_steps + int_steps, i] = 1.0
# hack:
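
For context on the clique-step kernels touched above: positions come from matrix products against a start-state mask and the action block, and velocity/acceleration/jerk are then obtained by multiplying with finite-difference matrices (state_vel = fd_1 @ state_position, and so on). A hedged toy version of that last step, using a simple backward-difference matrix (the real fd_1/fd_2/fd_3 are built elsewhere in this file and use a different stencil):

import torch

horizon, dof, dt = 5, 2, 0.1

# Illustrative first-order backward-difference matrix.
fd_1 = torch.zeros(horizon, horizon)
for i in range(1, horizon):
    fd_1[i, i] = 1.0 / dt
    fd_1[i, i - 1] = -1.0 / dt

position = torch.cumsum(torch.full((1, horizon, dof), 0.1), dim=1)  # straight-line motion
velocity = fd_1 @ position   # same pattern as state_vel = fd_1 @ state_position above
print(velocity[0, 1:, 0])    # approx. [1., 1., 1., 1.] -> 0.1 units per 0.1 s step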

View File

@@ -19,6 +19,7 @@ import torch
# CuRobo
from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
from curobo.util.torch_utils import get_torch_jit_decorator
# Local Folder
from .integration_utils import (
@@ -544,7 +545,7 @@ class TensorStepPositionCliqueKernel(TensorStepBase):
return new_signal
@torch.jit.script
@get_torch_jit_decorator(force_jit=True)
def filter_signal_jit(signal, kernel):
b, h, dof = signal.shape

View File

@@ -36,6 +36,7 @@ from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_info
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.tensor_util import copy_tensor
from curobo.util.torch_utils import get_torch_jit_decorator
@dataclass
@@ -298,9 +299,9 @@ class Goal(Sequence):
if self.goal_pose is not None:
self.goal_pose = self.goal_pose.to(tensor_args)
if self.goal_state is not None:
self.goal_state = self.goal_state.to(**vars(tensor_args))
self.goal_state = self.goal_state.to(**(tensor_args.as_torch_dict()))
if self.current_state is not None:
self.current_state = self.current_state.to(**vars(tensor_args))
self.current_state = self.current_state.to(**(tensor_args.as_torch_dict()))
return self
def copy_(self, goal: Goal, update_idx_buffers: bool = True):
@@ -350,6 +351,7 @@ class Goal(Sequence):
if ref_buffer is not None:
ref_buffer = ref_buffer.copy_(buffer)
else:
log_info("breaking reference")
ref_buffer = buffer.clone()
return ref_buffer
@@ -414,6 +416,7 @@ class Goal(Sequence):
@dataclass
class RolloutConfig:
tensor_args: TensorDeviceType
sum_horizon: bool = False
class RolloutBase:
@@ -578,7 +581,7 @@ class RolloutBase:
return q_js
@torch.jit.script
@get_torch_jit_decorator()
def tensor_repeat_seeds(tensor, num_seeds: int):
a = (
tensor.view(tensor.shape[0], 1, tensor.shape[-1])
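
Several hunks in this commit replace **vars(tensor_args) with **(tensor_args.as_torch_dict()). The motivation is that vars() splats every dataclass field into the torch factory call, which breaks as soon as TensorDeviceType grows a field that torch.randn/torch.zeros do not accept, whereas an explicit as_torch_dict() can return only device and dtype. A hedged stand-in illustrating the difference (field names beyond device/dtype are invented for illustration):

from dataclasses import dataclass

import torch


@dataclass
class TensorDeviceTypeSketch:
    # Stand-in for curobo.types.base.TensorDeviceType; the real class may carry
    # additional fields, which is exactly why vars(...) is unsafe to splat here.
    device: torch.device = torch.device("cpu")
    dtype: torch.dtype = torch.float32
    collision_geometry_dtype: torch.dtype = torch.float32  # extra, non-factory kwarg

    def as_torch_dict(self):
        # Only the kwargs that torch factory functions actually accept.
        return {"device": self.device, "dtype": self.dtype}


args = TensorDeviceTypeSketch()
x = torch.randn((1, 7), **args.as_torch_dict())   # ok
# torch.randn((1, 7), **vars(args))               # TypeError: unexpected keyword argument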