Significantly improved convergence for mesh and cuboid, new ESDF collision.
This commit is contained in:
@@ -41,6 +41,7 @@ from curobo.types.robot import JointState, RobotConfig
|
||||
from curobo.types.tensor import T_BDOF, T_DOF, T_BValue_bool, T_BValue_float
|
||||
from curobo.util.helpers import list_idx_if_not_none
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator, is_torch_compile_available
|
||||
from curobo.util.trajectory import (
|
||||
InterpolateType,
|
||||
calculate_dt_no_clamp,
|
||||
@@ -877,24 +878,37 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
result.metrics.goalset_index = metrics.goalset_index
|
||||
|
||||
st_time = time.time()
|
||||
feasible = torch.all(result.metrics.feasible, dim=-1)
|
||||
if result.metrics.cspace_error is None and result.metrics.position_error is None:
|
||||
raise log_error("convergence check requires either goal_pose or goal_state")
|
||||
|
||||
if result.metrics.position_error is not None:
|
||||
converge = torch.logical_and(
|
||||
result.metrics.position_error[..., -1] <= self.position_threshold,
|
||||
result.metrics.rotation_error[..., -1] <= self.rotation_threshold,
|
||||
)
|
||||
elif result.metrics.cspace_error is not None:
|
||||
converge = result.metrics.cspace_error[..., -1] <= self.cspace_threshold
|
||||
else:
|
||||
raise ValueError("convergence check requires either goal_pose or goal_state")
|
||||
success = jit_feasible_success(
|
||||
result.metrics.feasible,
|
||||
result.metrics.position_error,
|
||||
result.metrics.rotation_error,
|
||||
result.metrics.cspace_error,
|
||||
self.position_threshold,
|
||||
self.rotation_threshold,
|
||||
self.cspace_threshold,
|
||||
)
|
||||
if False:
|
||||
feasible = torch.all(result.metrics.feasible, dim=-1)
|
||||
|
||||
success = torch.logical_and(feasible, converge)
|
||||
if result.metrics.position_error is not None:
|
||||
converge = torch.logical_and(
|
||||
result.metrics.position_error[..., -1] <= self.position_threshold,
|
||||
result.metrics.rotation_error[..., -1] <= self.rotation_threshold,
|
||||
)
|
||||
elif result.metrics.cspace_error is not None:
|
||||
converge = result.metrics.cspace_error[..., -1] <= self.cspace_threshold
|
||||
else:
|
||||
raise ValueError("convergence check requires either goal_pose or goal_state")
|
||||
|
||||
success = torch.logical_and(feasible, converge)
|
||||
if return_all_solutions:
|
||||
traj_result = TrajResult(
|
||||
success=success,
|
||||
goal=goal,
|
||||
solution=result.action.scale(self.solver_dt / opt_dt.view(-1, 1, 1)),
|
||||
solution=result.action.scale_by_dt(self.solver_dt_tensor, opt_dt.view(-1, 1, 1)),
|
||||
seed=seed_traj,
|
||||
position_error=result.metrics.position_error,
|
||||
rotation_error=result.metrics.rotation_error,
|
||||
@@ -928,49 +942,86 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
)
|
||||
|
||||
with profiler.record_function("trajopt/best_select"):
|
||||
success[~smooth_label] = False
|
||||
# get the best solution:
|
||||
if result.metrics.pose_error is not None:
|
||||
convergence_error = result.metrics.pose_error[..., -1]
|
||||
elif result.metrics.cspace_error is not None:
|
||||
convergence_error = result.metrics.cspace_error[..., -1]
|
||||
if True: # not get_torch_jit_decorator() == torch.jit.script:
|
||||
# This only works if torch compile is available:
|
||||
(
|
||||
idx,
|
||||
position_error,
|
||||
rotation_error,
|
||||
cspace_error,
|
||||
goalset_index,
|
||||
opt_dt,
|
||||
success,
|
||||
) = jit_trajopt_best_select(
|
||||
success,
|
||||
smooth_label,
|
||||
result.metrics.cspace_error,
|
||||
result.metrics.pose_error,
|
||||
result.metrics.position_error,
|
||||
result.metrics.rotation_error,
|
||||
result.metrics.goalset_index,
|
||||
result.metrics.cost,
|
||||
smooth_cost,
|
||||
batch_mode,
|
||||
goal.batch,
|
||||
num_seeds,
|
||||
self._col,
|
||||
opt_dt,
|
||||
)
|
||||
if batch_mode:
|
||||
last_tstep = [last_tstep[i] for i in idx]
|
||||
else:
|
||||
last_tstep = [last_tstep[idx.item()]]
|
||||
best_act_seq = result.action[idx]
|
||||
best_raw_action = result.raw_action[idx]
|
||||
interpolated_traj = interpolated_trajs[idx]
|
||||
|
||||
else:
|
||||
raise ValueError("convergence check requires either goal_pose or goal_state")
|
||||
running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
|
||||
error = convergence_error + smooth_cost + running_cost
|
||||
error[~success] += 10000.0
|
||||
if batch_mode:
|
||||
idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
|
||||
idx = idx + num_seeds * self._col
|
||||
last_tstep = [last_tstep[i] for i in idx]
|
||||
success = success[idx]
|
||||
else:
|
||||
idx = torch.argmin(error, dim=0)
|
||||
success[~smooth_label] = False
|
||||
# get the best solution:
|
||||
if result.metrics.pose_error is not None:
|
||||
convergence_error = result.metrics.pose_error[..., -1]
|
||||
elif result.metrics.cspace_error is not None:
|
||||
convergence_error = result.metrics.cspace_error[..., -1]
|
||||
else:
|
||||
raise ValueError(
|
||||
"convergence check requires either goal_pose or goal_state"
|
||||
)
|
||||
running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
|
||||
error = convergence_error + smooth_cost + running_cost
|
||||
error[~success] += 10000.0
|
||||
if batch_mode:
|
||||
idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
|
||||
idx = idx + num_seeds * self._col
|
||||
last_tstep = [last_tstep[i] for i in idx]
|
||||
success = success[idx]
|
||||
else:
|
||||
idx = torch.argmin(error, dim=0)
|
||||
|
||||
last_tstep = [last_tstep[idx.item()]]
|
||||
success = success[idx : idx + 1]
|
||||
last_tstep = [last_tstep[idx.item()]]
|
||||
success = success[idx : idx + 1]
|
||||
|
||||
best_act_seq = result.action[idx]
|
||||
best_raw_action = result.raw_action[idx]
|
||||
interpolated_traj = interpolated_trajs[idx]
|
||||
goalset_index = position_error = rotation_error = cspace_error = None
|
||||
if result.metrics.position_error is not None:
|
||||
position_error = result.metrics.position_error[idx, -1]
|
||||
if result.metrics.rotation_error is not None:
|
||||
rotation_error = result.metrics.rotation_error[idx, -1]
|
||||
if result.metrics.cspace_error is not None:
|
||||
cspace_error = result.metrics.cspace_error[idx, -1]
|
||||
if result.metrics.goalset_index is not None:
|
||||
goalset_index = result.metrics.goalset_index[idx, -1]
|
||||
best_act_seq = result.action[idx]
|
||||
best_raw_action = result.raw_action[idx]
|
||||
interpolated_traj = interpolated_trajs[idx]
|
||||
goalset_index = position_error = rotation_error = cspace_error = None
|
||||
if result.metrics.position_error is not None:
|
||||
position_error = result.metrics.position_error[idx, -1]
|
||||
if result.metrics.rotation_error is not None:
|
||||
rotation_error = result.metrics.rotation_error[idx, -1]
|
||||
if result.metrics.cspace_error is not None:
|
||||
cspace_error = result.metrics.cspace_error[idx, -1]
|
||||
if result.metrics.goalset_index is not None:
|
||||
goalset_index = result.metrics.goalset_index[idx, -1]
|
||||
|
||||
opt_dt = opt_dt[idx]
|
||||
opt_dt = opt_dt[idx]
|
||||
if self.sync_cuda_time:
|
||||
torch.cuda.synchronize()
|
||||
if len(best_act_seq.shape) == 3:
|
||||
opt_dt_v = opt_dt.view(-1, 1, 1)
|
||||
else:
|
||||
opt_dt_v = opt_dt.view(1, 1)
|
||||
opt_solution = best_act_seq.scale(self.solver_dt / opt_dt_v)
|
||||
opt_solution = best_act_seq.scale_by_dt(self.solver_dt_tensor, opt_dt_v)
|
||||
select_time = time.time() - st_time
|
||||
debug_info = None
|
||||
if self.store_debug_in_result:
|
||||
@@ -1174,7 +1225,7 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
self._max_joint_vel,
|
||||
self._max_joint_acc,
|
||||
self._max_joint_jerk,
|
||||
self.solver_dt,
|
||||
self.solver_dt_tensor,
|
||||
kind=self.interpolation_type,
|
||||
tensor_args=self.tensor_args,
|
||||
out_traj_state=self._interpolated_traj_buffer,
|
||||
@@ -1224,7 +1275,12 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
|
||||
@property
|
||||
def solver_dt(self):
|
||||
return self.solver.safety_rollout.dynamics_model.dt_traj_params.base_dt
|
||||
return self.solver.safety_rollout.dynamics_model.traj_dt[0]
|
||||
# return self.solver.safety_rollout.dynamics_model.dt_traj_params.base_dt
|
||||
|
||||
@property
|
||||
def solver_dt_tensor(self):
|
||||
return self.solver.safety_rollout.dynamics_model.traj_dt[0]
|
||||
|
||||
def update_solver_dt(
|
||||
self,
|
||||
@@ -1254,3 +1310,79 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
for rollout in rollouts
|
||||
if isinstance(rollout, ArmReacher)
|
||||
]
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def jit_feasible_success(
|
||||
feasible,
|
||||
position_error: Union[torch.Tensor, None],
|
||||
rotation_error: Union[torch.Tensor, None],
|
||||
cspace_error: Union[torch.Tensor, None],
|
||||
position_threshold: float,
|
||||
rotation_threshold: float,
|
||||
cspace_threshold: float,
|
||||
):
|
||||
feasible = torch.all(feasible, dim=-1)
|
||||
converge = feasible
|
||||
if position_error is not None and rotation_error is not None:
|
||||
converge = torch.logical_and(
|
||||
position_error[..., -1] <= position_threshold,
|
||||
rotation_error[..., -1] <= rotation_threshold,
|
||||
)
|
||||
elif cspace_error is not None:
|
||||
converge = cspace_error[..., -1] <= cspace_threshold
|
||||
|
||||
success = torch.logical_and(feasible, converge)
|
||||
return success
|
||||
|
||||
|
||||
@get_torch_jit_decorator(only_valid_for_compile=True)
|
||||
def jit_trajopt_best_select(
|
||||
success,
|
||||
smooth_label,
|
||||
cspace_error: Union[torch.Tensor, None],
|
||||
pose_error: Union[torch.Tensor, None],
|
||||
position_error: Union[torch.Tensor, None],
|
||||
rotation_error: Union[torch.Tensor, None],
|
||||
goalset_index: Union[torch.Tensor, None],
|
||||
cost,
|
||||
smooth_cost,
|
||||
batch_mode: bool,
|
||||
batch: int,
|
||||
num_seeds: int,
|
||||
col,
|
||||
opt_dt,
|
||||
):
|
||||
success[~smooth_label] = False
|
||||
convergence_error = 0
|
||||
# get the best solution:
|
||||
if pose_error is not None:
|
||||
convergence_error = pose_error[..., -1]
|
||||
elif cspace_error is not None:
|
||||
convergence_error = cspace_error[..., -1]
|
||||
|
||||
running_cost = torch.mean(cost, dim=-1) * 0.0001
|
||||
error = convergence_error + smooth_cost + running_cost
|
||||
error[~success] += 10000.0
|
||||
if batch_mode:
|
||||
idx = torch.argmin(error.view(batch, num_seeds), dim=-1)
|
||||
idx = idx + num_seeds * col
|
||||
success = success[idx]
|
||||
else:
|
||||
idx = torch.argmin(error, dim=0)
|
||||
|
||||
success = success[idx : idx + 1]
|
||||
|
||||
# goalset_index = position_error = rotation_error = cspace_error = None
|
||||
if position_error is not None:
|
||||
position_error = position_error[idx, -1]
|
||||
if rotation_error is not None:
|
||||
rotation_error = rotation_error[idx, -1]
|
||||
if cspace_error is not None:
|
||||
cspace_error = cspace_error[idx, -1]
|
||||
if goalset_index is not None:
|
||||
goalset_index = goalset_index[idx, -1]
|
||||
|
||||
opt_dt = opt_dt[idx]
|
||||
|
||||
return idx, position_error, rotation_error, cspace_error, goalset_index, opt_dt, success
|
||||
|
||||
Reference in New Issue
Block a user