Significantly improved convergence for mesh and cuboid, new ESDF collision.

This commit is contained in:
Balakumar Sundaralingam
2024-03-18 11:19:48 -07:00
parent 286b3820a5
commit b1f63e8778
100 changed files with 7587 additions and 2589 deletions

View File

@@ -41,6 +41,7 @@ from curobo.types.robot import JointState, RobotConfig
from curobo.types.tensor import T_BDOF, T_DOF, T_BValue_bool, T_BValue_float
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.torch_utils import get_torch_jit_decorator, is_torch_compile_available
from curobo.util.trajectory import (
InterpolateType,
calculate_dt_no_clamp,
@@ -877,24 +878,37 @@ class TrajOptSolver(TrajOptSolverConfig):
result.metrics.goalset_index = metrics.goalset_index
st_time = time.time()
feasible = torch.all(result.metrics.feasible, dim=-1)
if result.metrics.cspace_error is None and result.metrics.position_error is None:
raise log_error("convergence check requires either goal_pose or goal_state")
if result.metrics.position_error is not None:
converge = torch.logical_and(
result.metrics.position_error[..., -1] <= self.position_threshold,
result.metrics.rotation_error[..., -1] <= self.rotation_threshold,
)
elif result.metrics.cspace_error is not None:
converge = result.metrics.cspace_error[..., -1] <= self.cspace_threshold
else:
raise ValueError("convergence check requires either goal_pose or goal_state")
success = jit_feasible_success(
result.metrics.feasible,
result.metrics.position_error,
result.metrics.rotation_error,
result.metrics.cspace_error,
self.position_threshold,
self.rotation_threshold,
self.cspace_threshold,
)
if False:
feasible = torch.all(result.metrics.feasible, dim=-1)
success = torch.logical_and(feasible, converge)
if result.metrics.position_error is not None:
converge = torch.logical_and(
result.metrics.position_error[..., -1] <= self.position_threshold,
result.metrics.rotation_error[..., -1] <= self.rotation_threshold,
)
elif result.metrics.cspace_error is not None:
converge = result.metrics.cspace_error[..., -1] <= self.cspace_threshold
else:
raise ValueError("convergence check requires either goal_pose or goal_state")
success = torch.logical_and(feasible, converge)
if return_all_solutions:
traj_result = TrajResult(
success=success,
goal=goal,
solution=result.action.scale(self.solver_dt / opt_dt.view(-1, 1, 1)),
solution=result.action.scale_by_dt(self.solver_dt_tensor, opt_dt.view(-1, 1, 1)),
seed=seed_traj,
position_error=result.metrics.position_error,
rotation_error=result.metrics.rotation_error,
@@ -928,49 +942,86 @@ class TrajOptSolver(TrajOptSolverConfig):
)
with profiler.record_function("trajopt/best_select"):
success[~smooth_label] = False
# get the best solution:
if result.metrics.pose_error is not None:
convergence_error = result.metrics.pose_error[..., -1]
elif result.metrics.cspace_error is not None:
convergence_error = result.metrics.cspace_error[..., -1]
if True: # not get_torch_jit_decorator() == torch.jit.script:
# This only works if torch compile is available:
(
idx,
position_error,
rotation_error,
cspace_error,
goalset_index,
opt_dt,
success,
) = jit_trajopt_best_select(
success,
smooth_label,
result.metrics.cspace_error,
result.metrics.pose_error,
result.metrics.position_error,
result.metrics.rotation_error,
result.metrics.goalset_index,
result.metrics.cost,
smooth_cost,
batch_mode,
goal.batch,
num_seeds,
self._col,
opt_dt,
)
if batch_mode:
last_tstep = [last_tstep[i] for i in idx]
else:
last_tstep = [last_tstep[idx.item()]]
best_act_seq = result.action[idx]
best_raw_action = result.raw_action[idx]
interpolated_traj = interpolated_trajs[idx]
else:
raise ValueError("convergence check requires either goal_pose or goal_state")
running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
error = convergence_error + smooth_cost + running_cost
error[~success] += 10000.0
if batch_mode:
idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
idx = idx + num_seeds * self._col
last_tstep = [last_tstep[i] for i in idx]
success = success[idx]
else:
idx = torch.argmin(error, dim=0)
success[~smooth_label] = False
# get the best solution:
if result.metrics.pose_error is not None:
convergence_error = result.metrics.pose_error[..., -1]
elif result.metrics.cspace_error is not None:
convergence_error = result.metrics.cspace_error[..., -1]
else:
raise ValueError(
"convergence check requires either goal_pose or goal_state"
)
running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
error = convergence_error + smooth_cost + running_cost
error[~success] += 10000.0
if batch_mode:
idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
idx = idx + num_seeds * self._col
last_tstep = [last_tstep[i] for i in idx]
success = success[idx]
else:
idx = torch.argmin(error, dim=0)
last_tstep = [last_tstep[idx.item()]]
success = success[idx : idx + 1]
last_tstep = [last_tstep[idx.item()]]
success = success[idx : idx + 1]
best_act_seq = result.action[idx]
best_raw_action = result.raw_action[idx]
interpolated_traj = interpolated_trajs[idx]
goalset_index = position_error = rotation_error = cspace_error = None
if result.metrics.position_error is not None:
position_error = result.metrics.position_error[idx, -1]
if result.metrics.rotation_error is not None:
rotation_error = result.metrics.rotation_error[idx, -1]
if result.metrics.cspace_error is not None:
cspace_error = result.metrics.cspace_error[idx, -1]
if result.metrics.goalset_index is not None:
goalset_index = result.metrics.goalset_index[idx, -1]
best_act_seq = result.action[idx]
best_raw_action = result.raw_action[idx]
interpolated_traj = interpolated_trajs[idx]
goalset_index = position_error = rotation_error = cspace_error = None
if result.metrics.position_error is not None:
position_error = result.metrics.position_error[idx, -1]
if result.metrics.rotation_error is not None:
rotation_error = result.metrics.rotation_error[idx, -1]
if result.metrics.cspace_error is not None:
cspace_error = result.metrics.cspace_error[idx, -1]
if result.metrics.goalset_index is not None:
goalset_index = result.metrics.goalset_index[idx, -1]
opt_dt = opt_dt[idx]
opt_dt = opt_dt[idx]
if self.sync_cuda_time:
torch.cuda.synchronize()
if len(best_act_seq.shape) == 3:
opt_dt_v = opt_dt.view(-1, 1, 1)
else:
opt_dt_v = opt_dt.view(1, 1)
opt_solution = best_act_seq.scale(self.solver_dt / opt_dt_v)
opt_solution = best_act_seq.scale_by_dt(self.solver_dt_tensor, opt_dt_v)
select_time = time.time() - st_time
debug_info = None
if self.store_debug_in_result:
@@ -1174,7 +1225,7 @@ class TrajOptSolver(TrajOptSolverConfig):
self._max_joint_vel,
self._max_joint_acc,
self._max_joint_jerk,
self.solver_dt,
self.solver_dt_tensor,
kind=self.interpolation_type,
tensor_args=self.tensor_args,
out_traj_state=self._interpolated_traj_buffer,
@@ -1224,7 +1275,12 @@ class TrajOptSolver(TrajOptSolverConfig):
@property
def solver_dt(self):
return self.solver.safety_rollout.dynamics_model.dt_traj_params.base_dt
return self.solver.safety_rollout.dynamics_model.traj_dt[0]
# return self.solver.safety_rollout.dynamics_model.dt_traj_params.base_dt
@property
def solver_dt_tensor(self):
return self.solver.safety_rollout.dynamics_model.traj_dt[0]
def update_solver_dt(
self,
@@ -1254,3 +1310,79 @@ class TrajOptSolver(TrajOptSolverConfig):
for rollout in rollouts
if isinstance(rollout, ArmReacher)
]
@get_torch_jit_decorator()
def jit_feasible_success(
feasible,
position_error: Union[torch.Tensor, None],
rotation_error: Union[torch.Tensor, None],
cspace_error: Union[torch.Tensor, None],
position_threshold: float,
rotation_threshold: float,
cspace_threshold: float,
):
feasible = torch.all(feasible, dim=-1)
converge = feasible
if position_error is not None and rotation_error is not None:
converge = torch.logical_and(
position_error[..., -1] <= position_threshold,
rotation_error[..., -1] <= rotation_threshold,
)
elif cspace_error is not None:
converge = cspace_error[..., -1] <= cspace_threshold
success = torch.logical_and(feasible, converge)
return success
@get_torch_jit_decorator(only_valid_for_compile=True)
def jit_trajopt_best_select(
success,
smooth_label,
cspace_error: Union[torch.Tensor, None],
pose_error: Union[torch.Tensor, None],
position_error: Union[torch.Tensor, None],
rotation_error: Union[torch.Tensor, None],
goalset_index: Union[torch.Tensor, None],
cost,
smooth_cost,
batch_mode: bool,
batch: int,
num_seeds: int,
col,
opt_dt,
):
success[~smooth_label] = False
convergence_error = 0
# get the best solution:
if pose_error is not None:
convergence_error = pose_error[..., -1]
elif cspace_error is not None:
convergence_error = cspace_error[..., -1]
running_cost = torch.mean(cost, dim=-1) * 0.0001
error = convergence_error + smooth_cost + running_cost
error[~success] += 10000.0
if batch_mode:
idx = torch.argmin(error.view(batch, num_seeds), dim=-1)
idx = idx + num_seeds * col
success = success[idx]
else:
idx = torch.argmin(error, dim=0)
success = success[idx : idx + 1]
# goalset_index = position_error = rotation_error = cspace_error = None
if position_error is not None:
position_error = position_error[idx, -1]
if rotation_error is not None:
rotation_error = rotation_error[idx, -1]
if cspace_error is not None:
cspace_error = cspace_error[idx, -1]
if goalset_index is not None:
goalset_index = goalset_index[idx, -1]
opt_dt = opt_dt[idx]
return idx, position_error, rotation_error, cspace_error, goalset_index, opt_dt, success