update to 0.6.2
@@ -77,6 +77,7 @@ class TrajOptSolverConfig:
use_cuda_graph_metrics: bool = False
trim_steps: Optional[List[int]] = None
store_debug_in_result: bool = False
optimize_dt: bool = True

@staticmethod
@profiler.record_function("trajopt_config/load_from_robot_config")
@@ -107,7 +108,7 @@ class TrajOptSolverConfig:
collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.PRIMITIVE,
traj_evaluator_config: TrajEvaluatorConfig = TrajEvaluatorConfig(),
traj_evaluator: Optional[TrajEvaluator] = None,
minimize_jerk: Optional[bool] = None,
minimize_jerk: bool = True,
use_gradient_descent: bool = False,
collision_cache: Optional[Dict[str, int]] = None,
n_collision_envs: Optional[int] = None,
@@ -126,7 +127,9 @@ class TrajOptSolverConfig:
smooth_weight: Optional[List[float]] = None,
state_finite_difference_mode: Optional[str] = None,
filter_robot_command: bool = False,
optimize_dt: bool = True,
):
# NOTE: Don't have default optimize_dt, instead read from a configuration file.
# use default values, disable environment collision checking
if isinstance(robot_cfg, str):
robot_cfg = load_yaml(join_path(get_robot_configs_path(), robot_cfg))["robot_cfg"]
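
The hunks above thread a new optimize_dt flag through both the TrajOptSolverConfig dataclass and load_from_robot_config. A minimal sketch of how a caller might set it, assuming the module path curobo.wrap.reacher.trajopt and that the world/tensor arguments can stay at their defaults; only the keyword names are taken from the signature shown in the diff:

from curobo.wrap.reacher.trajopt import TrajOptSolver, TrajOptSolverConfig

# Build a solver config from a packaged robot YAML; a plain string is resolved
# through get_robot_configs_path(), as the isinstance(robot_cfg, str) branch
# above shows. Leaving the world model unset is an assumption of this sketch.
trajopt_config = TrajOptSolverConfig.load_from_robot_config(
    robot_cfg="franka.yml",
    minimize_jerk=True,
    filter_robot_command=False,
    optimize_dt=True,  # new flag in this commit; also stored on the config dataclass
)
trajopt_solver = TrajOptSolver(trajopt_config)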
@@ -199,6 +202,7 @@ class TrajOptSolverConfig:
fixed_iters = True
grad_config_data["lbfgs"]["store_debug"] = store_debug
config_data["mppi"]["store_debug"] = store_debug
store_debug_in_result = True

if use_cuda_graph is not None:
config_data["mppi"]["use_cuda_graph"] = use_cuda_graph
@@ -332,6 +336,7 @@ class TrajOptSolverConfig:
use_cuda_graph_metrics=use_cuda_graph,
trim_steps=trim_steps,
store_debug_in_result=store_debug_in_result,
optimize_dt=optimize_dt,
)
return trajopt_cfg

@@ -354,6 +359,7 @@ class TrajResult(Sequence):
smooth_label: Optional[T_BValue_bool] = None
optimized_dt: Optional[torch.Tensor] = None
raw_solution: Optional[JointState] = None
raw_action: Optional[torch.Tensor] = None

def __getitem__(self, idx):
# position_error = rotation_error = cspace_error = path_buffer_last_tstep = None
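
The TrajResult hunk above adds optimized_dt, raw_solution, and raw_action alongside the existing smooth_label, and __getitem__ makes the result indexable. A hedged sketch of reading those fields; the helper below is illustrative and not part of cuRobo, and the import path is an assumption:

from curobo.wrap.reacher.trajopt import TrajResult

def summarize(result: TrajResult) -> None:
    # Only the attribute names printed here come from the diff.
    if result.optimized_dt is not None:
        print("optimized dt:", result.optimized_dt)                  # torch.Tensor
    print("raw (pre-interpolation) solution:", result.raw_solution)  # JointState
    print("raw optimizer action:", result.raw_action)                # torch.Tensor
    print("first element via __getitem__:", result[0])               # Sequence indexing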
@@ -392,17 +398,14 @@ class TrajOptSolver(TrajOptSolverConfig):
def __init__(self, config: TrajOptSolverConfig) -> None:
super().__init__(**vars(config))
self.dof = self.rollout_fn.d_action
self.delta_vec = action_interpolate_kernel(2, self.traj_tsteps, self.tensor_args).unsqueeze(
0
)
self.waypoint_delta_vec = interpolate_kernel(
3, int(self.traj_tsteps / 2), self.tensor_args
).unsqueeze(0)
self.waypoint_delta_vec = torch.roll(self.waypoint_delta_vec, -1, dims=1)
self.waypoint_delta_vec[:, -1, :] = self.waypoint_delta_vec[:, -2, :]
assert self.traj_tsteps / 2 != 0.0
self.solver.update_nenvs(self.num_seeds)
self.action_horizon = self.rollout_fn.action_horizon
self.delta_vec = interpolate_kernel(2, self.action_horizon, self.tensor_args).unsqueeze(0)

self.waypoint_delta_vec = interpolate_kernel(
3, int(self.action_horizon / 2), self.tensor_args
).unsqueeze(0)
assert self.action_horizon / 2 != 0.0
self.solver.update_nenvs(self.num_seeds)
self._max_joint_vel = (
self.solver.safety_rollout.state_bounds.velocity.view(2, self.dof)[1, :].reshape(
1, 1, self.dof
@@ -410,7 +413,6 @@ class TrajOptSolver(TrajOptSolverConfig):
) - 0.02
self._max_joint_acc = self.rollout_fn.state_bounds.acceleration[1, :] - 0.02
self._max_joint_jerk = self.rollout_fn.state_bounds.jerk[1, :] - 0.02
# self._max_joint_jerk = self._max_joint_jerk * 0.0 + 10.0
self._num_seeds = -1
self._col = None
if self.traj_evaluator is None:
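
The joint-limit lines above slice the upper bound out of a (2, dof) bounds tensor and subtract a 0.02 safety margin. A plain-torch sketch of that slicing, assuming row 0 holds lower limits and row 1 upper limits (consistent with the [1, :] indexing); the limit values are made up:

import torch

dof = 7
velocity_bounds = torch.tensor([[-2.0] * dof, [2.0] * dof])  # hypothetical (2, dof) limits

# Upper limits minus a small margin, shaped to broadcast over (batch, horizon, dof).
max_joint_vel = velocity_bounds.view(2, dof)[1, :].reshape(1, 1, dof) - 0.02
print(max_joint_vel.shape)  # torch.Size([1, 1, 7])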
@@ -844,9 +846,9 @@ class TrajOptSolver(TrajOptSolverConfig):
result.metrics = self.interpolate_rollout.get_metrics(interpolated_trajs)

st_time = time.time()

feasible = torch.all(result.metrics.feasible, dim=-1)

# if self.num_seeds == 1:
# print(result.metrics)
if result.metrics.position_error is not None:
converge = torch.logical_and(
result.metrics.position_error[..., -1] <= self.position_threshold,
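
The metrics block above gates success on per-step feasibility and a final-timestep position-error threshold. A plain-torch sketch of that pattern with stand-in tensors; the second operand of logical_and is cut off by the hunk, so the rotation-error check is only a placeholder here:

import torch

n_seeds, horizon = 4, 32
feasible_per_step = torch.rand(n_seeds, horizon) > 0.1  # stand-in for result.metrics.feasible
position_error = torch.rand(n_seeds, horizon) * 0.01    # stand-in for metrics.position_error

feasible = torch.all(feasible_per_step, dim=-1)          # a seed passes only if every step is feasible
converge = torch.logical_and(
    position_error[..., -1] <= 0.005,                    # final-step position threshold
    torch.ones(n_seeds, dtype=torch.bool),               # placeholder for the rotation-error term
)
success = torch.logical_and(feasible, converge)
print(success)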
@@ -874,6 +876,7 @@ class TrajOptSolver(TrajOptSolverConfig):
cspace_error=result.metrics.cspace_error,
optimized_dt=opt_dt,
raw_solution=result.action,
raw_action=result.raw_action,
)
else:
# get path length:
@@ -904,7 +907,6 @@ class TrajOptSolver(TrajOptSolverConfig):
convergence_error = result.metrics.cspace_error[..., -1]
else:
raise ValueError("convergence check requires either goal_pose or goal_state")

error = convergence_error + smooth_cost
error[~success] += 10000.0
if batch_mode:
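
The error lines above combine the convergence error with a smoothness cost and add a 10000.0 penalty to unsuccessful seeds, so the later index selection only falls back to a failed seed when every seed failed. A small sketch with made-up numbers; using argmin for the selection is an assumption, since the hunk does not show how idx is computed:

import torch

convergence_error = torch.tensor([0.02, 0.01, 0.05, 0.03])  # one entry per seed
smooth_cost = torch.tensor([0.10, 0.40, 0.05, 0.20])
success = torch.tensor([True, True, False, True])

error = convergence_error + smooth_cost
error[~success] += 10000.0           # push failed seeds out of contention
idx = torch.argmin(error)            # assumed selection rule; picks seed 1 here
print(int(idx), bool(success[idx]))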
@@ -919,6 +921,7 @@ class TrajOptSolver(TrajOptSolverConfig):
success = success[idx : idx + 1]

best_act_seq = result.action[idx]
best_raw_action = result.raw_action[idx]
interpolated_traj = interpolated_trajs[idx]
position_error = rotation_error = cspace_error = None
if result.metrics.position_error is not None:
@@ -961,6 +964,7 @@ class TrajOptSolver(TrajOptSolverConfig):
smooth_label=smooth_label,
optimized_dt=opt_dt,
raw_solution=best_act_seq,
raw_action=best_raw_action,
)
return traj_result

@@ -1044,7 +1048,7 @@ class TrajOptSolver(TrajOptSolverConfig):
if goal.goal_state is not None and self.use_cspace_seed:
# get linear seed
seed_traj = self.get_seeds(goal.current_state, goal.goal_state, num_seeds=num_seeds)
# .view(batch_size, self.num_seeds, self.traj_tsteps, self.dof)
# .view(batch_size, self.num_seeds, self.action_horizon, self.dof)
else:
# get start repeat seed:
log_info("No goal state found, using current config to seed")
@@ -1063,7 +1067,7 @@ class TrajOptSolver(TrajOptSolverConfig):
seed_traj = torch.cat((seed_traj, new_seeds), dim=0) # n_seed, batch, h, dof

seed_traj = seed_traj.view(
total_seeds, self.traj_tsteps, self.dof
total_seeds, self.action_horizon, self.dof
) # n_seeds,batch, h, dof
return seed_traj

@@ -1079,27 +1083,27 @@ class TrajOptSolver(TrajOptSolverConfig):
if n_seeds["linear"] > 0:
linear_seed = self.get_linear_seed(start_state, goal_state)

linear_seeds = linear_seed.view(1, -1, self.traj_tsteps, self.dof).repeat(
linear_seeds = linear_seed.view(1, -1, self.action_horizon, self.dof).repeat(
1, n_seeds["linear"], 1, 1
)
seed_set.append(linear_seeds)
if n_seeds["bias"] > 0:
bias_seed = self.get_bias_seed(start_state, goal_state)
bias_seeds = bias_seed.view(1, -1, self.traj_tsteps, self.dof).repeat(
bias_seeds = bias_seed.view(1, -1, self.action_horizon, self.dof).repeat(
1, n_seeds["bias"], 1, 1
)
seed_set.append(bias_seeds)
if n_seeds["start"] > 0:
bias_seed = self.get_start_seed(start_state)

bias_seeds = bias_seed.view(1, -1, self.traj_tsteps, self.dof).repeat(
bias_seeds = bias_seed.view(1, -1, self.action_horizon, self.dof).repeat(
1, n_seeds["start"], 1, 1
)
seed_set.append(bias_seeds)
if n_seeds["goal"] > 0:
bias_seed = self.get_start_seed(goal_state)

bias_seeds = bias_seed.view(1, -1, self.traj_tsteps, self.dof).repeat(
bias_seeds = bias_seed.view(1, -1, self.action_horizon, self.dof).repeat(
1, n_seeds["goal"], 1, 1
)
seed_set.append(bias_seeds)
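
The seed hunks above, together with the earlier reshape to (total_seeds, action_horizon, dof), all follow one pattern: build a seed of shape (batch, horizon, dof), view it as (1, batch, horizon, dof), repeat it along the seed dimension, collect the pieces, and flatten. A plain-torch sketch of that bookkeeping; the linear ramp stands in for get_linear_seed, whose implementation is not shown in this diff:

import torch

batch, action_horizon, dof = 2, 32, 7
start = torch.zeros(batch, dof)
goal = torch.ones(batch, dof)

# Linear ramp from start to goal over the horizon -> (batch, horizon, dof).
alpha = torch.linspace(0.0, 1.0, action_horizon).view(1, -1, 1)
linear_seed = start.unsqueeze(1) * (1 - alpha) + goal.unsqueeze(1) * alpha

n_linear = 3
linear_seeds = linear_seed.view(1, -1, action_horizon, dof).repeat(1, n_linear, 1, 1)
seed_set = [linear_seeds]                            # bias/start/goal seeds are appended the same way

seed_traj = torch.cat(seed_set, dim=0)               # n_seed_groups, batch * n, horizon, dof
seed_traj = seed_traj.view(-1, action_horizon, dof)  # total_seeds, horizon, dof
print(seed_traj.shape)                               # torch.Size([6, 32, 7])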
@@ -1142,6 +1146,7 @@ class TrajOptSolver(TrajOptSolverConfig):
tensor_args=self.tensor_args,
out_traj_state=self._interpolated_traj_buffer,
min_dt=self.traj_evaluator_config.min_dt,
optimize_dt=self.optimize_dt,
)

return state, last_tstep, opt_dt