update to 0.6.2
This commit is contained in:
@@ -122,7 +122,7 @@ class MotionGenConfig:
|
||||
interpolation_dt: float = 0.01
|
||||
|
||||
#: scale initial dt by this value to finetune trajectory optimization.
|
||||
finetune_dt_scale: float = 1.05
|
||||
finetune_dt_scale: float = 0.98
|
||||
|
||||
@staticmethod
|
||||
@profiler.record_function("motion_gen_config/load_from_robot_config")
|
||||
@@ -192,12 +192,13 @@ class MotionGenConfig:
|
||||
smooth_weight: List[float] = None,
|
||||
finetune_smooth_weight: Optional[List[float]] = None,
|
||||
state_finite_difference_mode: Optional[str] = None,
|
||||
finetune_dt_scale: float = 1.05,
|
||||
finetune_dt_scale: float = 0.98,
|
||||
maximum_trajectory_time: Optional[float] = None,
|
||||
maximum_trajectory_dt: float = 0.1,
|
||||
velocity_scale: Optional[Union[List[float], float]] = None,
|
||||
acceleration_scale: Optional[Union[List[float], float]] = None,
|
||||
jerk_scale: Optional[Union[List[float], float]] = None,
|
||||
optimize_dt: bool = True,
|
||||
):
|
||||
"""Load motion generation configuration from robot and world configurations.
|
||||
|
||||
@@ -279,7 +280,11 @@ class MotionGenConfig:
|
||||
"""
|
||||
|
||||
init_warp(tensor_args=tensor_args)
|
||||
|
||||
if js_trajopt_tsteps is not None:
|
||||
log_warn("js_trajopt_tsteps is deprecated, use trajopt_tsteps instead.")
|
||||
trajopt_tsteps = js_trajopt_tsteps
|
||||
if trajopt_tsteps is not None:
|
||||
js_trajopt_tsteps = trajopt_tsteps
|
||||
if velocity_scale is not None and isinstance(velocity_scale, float):
|
||||
velocity_scale = [velocity_scale]
|
||||
|
||||
@@ -318,6 +323,8 @@ class MotionGenConfig:
|
||||
|
||||
if isinstance(robot_cfg, str):
|
||||
robot_cfg = load_yaml(join_path(get_robot_configs_path(), robot_cfg))["robot_cfg"]
|
||||
elif isinstance(robot_cfg, Dict) and "robot_cfg" in robot_cfg.keys():
|
||||
robot_cfg = robot_cfg["robot_cfg"]
|
||||
if isinstance(robot_cfg, RobotConfig):
|
||||
if ee_link_name is not None:
|
||||
log_error("ee link cannot be changed after creating RobotConfig")
|
||||
@@ -445,6 +452,7 @@ class MotionGenConfig:
|
||||
state_finite_difference_mode=state_finite_difference_mode,
|
||||
filter_robot_command=filter_robot_command,
|
||||
minimize_jerk=minimize_jerk,
|
||||
optimize_dt=optimize_dt,
|
||||
)
|
||||
trajopt_solver = TrajOptSolver(trajopt_cfg)
|
||||
|
||||
@@ -465,8 +473,6 @@ class MotionGenConfig:
|
||||
self_collision_check=self_collision_check,
|
||||
self_collision_opt=self_collision_opt,
|
||||
grad_trajopt_iters=grad_trajopt_iters,
|
||||
# num_seeds=num_trajopt_noisy_seeds,
|
||||
# seed_ratio=trajopt_seed_ratio,
|
||||
interpolation_dt=interpolation_dt,
|
||||
use_particle_opt=trajopt_particle_opt,
|
||||
traj_evaluator_config=traj_evaluator_config,
|
||||
@@ -486,6 +492,7 @@ class MotionGenConfig:
|
||||
state_finite_difference_mode=state_finite_difference_mode,
|
||||
minimize_jerk=minimize_jerk,
|
||||
filter_robot_command=filter_robot_command,
|
||||
optimize_dt=optimize_dt,
|
||||
)
|
||||
js_trajopt_solver = TrajOptSolver(js_trajopt_cfg)
|
||||
|
||||
@@ -523,6 +530,7 @@ class MotionGenConfig:
|
||||
trim_steps=trim_steps,
|
||||
use_gradient_descent=use_gradient_descent,
|
||||
filter_robot_command=filter_robot_command,
|
||||
optimize_dt=optimize_dt,
|
||||
)
|
||||
|
||||
finetune_trajopt_solver = TrajOptSolver(finetune_trajopt_cfg)
|
||||
@@ -748,7 +756,9 @@ class MotionGenResult:
|
||||
|
||||
@property
|
||||
def motion_time(self):
|
||||
return self.optimized_dt * (self.optimized_plan.position.shape[-2] - 1)
|
||||
# -2 as last three timesteps have the same value
|
||||
# 0, 1 also have the same position value.
|
||||
return self.optimized_dt * (self.optimized_plan.position.shape[-2] - 1 - 2 - 1)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -762,7 +772,7 @@ class MotionGenPlanConfig:
|
||||
enable_graph_attempt: Optional[int] = 3
|
||||
disable_graph_attempt: Optional[int] = None
|
||||
ik_fail_return: Optional[int] = None
|
||||
partial_ik_opt: bool = True
|
||||
partial_ik_opt: bool = False
|
||||
num_ik_seeds: Optional[int] = None
|
||||
num_graph_seeds: Optional[int] = None
|
||||
num_trajopt_seeds: Optional[int] = None
|
||||
@@ -770,6 +780,7 @@ class MotionGenPlanConfig:
|
||||
fail_on_invalid_query: bool = True
|
||||
#: enables retiming trajectory optimization, useful for getting low jerk trajectories.
|
||||
enable_finetune_trajopt: bool = True
|
||||
parallel_finetune: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.enable_opt and not self.enable_graph:
|
||||
@@ -1103,8 +1114,11 @@ class MotionGen(MotionGenConfig):
|
||||
if plan_config.enable_graph:
|
||||
raise ValueError("Graph Search / Geometric Planner not supported in batch_env mode")
|
||||
|
||||
if plan_config.enable_graph:
|
||||
log_info("Batch mode enable graph is only supported with num_graph_seeds==1")
|
||||
if plan_config.enable_graph or (
|
||||
plan_config.enable_graph_attempt is not None
|
||||
and plan_config.max_attempts >= plan_config.enable_graph_attempt
|
||||
):
|
||||
log_warn("Batch mode enable graph is only supported with num_graph_seeds==1")
|
||||
plan_config.num_trajopt_seeds = 1
|
||||
plan_config.num_graph_seeds = 1
|
||||
solve_state.num_trajopt_seeds = 1
|
||||
@@ -1316,7 +1330,6 @@ class MotionGen(MotionGenConfig):
|
||||
trajopt_seed_traj = None
|
||||
trajopt_seed_success = None
|
||||
trajopt_newton_iters = None
|
||||
|
||||
graph_success = 0
|
||||
if len(start_state.shape) == 1:
|
||||
log_error("Joint state should be not a vector (dof) should be (bxdof)")
|
||||
@@ -1330,6 +1343,7 @@ class MotionGen(MotionGenConfig):
|
||||
plan_config.partial_ik_opt,
|
||||
link_poses,
|
||||
)
|
||||
|
||||
if not plan_config.enable_graph and plan_config.partial_ik_opt:
|
||||
ik_result.success[:] = True
|
||||
|
||||
@@ -1364,7 +1378,7 @@ class MotionGen(MotionGenConfig):
|
||||
if plan_config.enable_graph:
|
||||
interpolation_steps = None
|
||||
if plan_config.enable_opt:
|
||||
interpolation_steps = self.trajopt_solver.traj_tsteps - 4
|
||||
interpolation_steps = self.trajopt_solver.action_horizon
|
||||
log_info("MG: running GP")
|
||||
graph_result = self.graph_search(start_config, goal_config, interpolation_steps)
|
||||
trajopt_seed_success = graph_result.success
|
||||
@@ -1378,6 +1392,8 @@ class MotionGen(MotionGenConfig):
|
||||
|
||||
result.used_graph = True
|
||||
if plan_config.enable_opt:
|
||||
# print(result.graph_plan.position.shape, interpolation_steps,
|
||||
# graph_result.path_buffer_last_tstep)
|
||||
trajopt_seed = (
|
||||
result.graph_plan.position.view(
|
||||
1, # solve_state.batch_size,
|
||||
@@ -1389,12 +1405,11 @@ class MotionGen(MotionGenConfig):
|
||||
.contiguous()
|
||||
)
|
||||
trajopt_seed_traj = torch.zeros(
|
||||
(trajopt_seed.shape[0], 1, self.trajopt_solver.traj_tsteps, self._dof),
|
||||
(trajopt_seed.shape[0], 1, self.trajopt_solver.action_horizon, self._dof),
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
)
|
||||
trajopt_seed_traj[:, :, :-4, :] = trajopt_seed
|
||||
trajopt_seed_traj[:, :, -4:, :] = trajopt_seed_traj[:, :, -5:-4, :]
|
||||
trajopt_seed_traj[:, :, :interpolation_steps, :] = trajopt_seed
|
||||
trajopt_seed_success = ik_result.success.clone()
|
||||
trajopt_seed_success[ik_result.success] = graph_result.success
|
||||
|
||||
@@ -1497,7 +1512,7 @@ class MotionGen(MotionGenConfig):
|
||||
trajopt_seed_traj = trajopt_seed_traj.view(
|
||||
solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds,
|
||||
solve_state.batch_size,
|
||||
self.trajopt_solver.traj_tsteps,
|
||||
self.trajopt_solver.action_horizon,
|
||||
self._dof,
|
||||
).contiguous()
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
@@ -1511,31 +1526,27 @@ class MotionGen(MotionGenConfig):
|
||||
trajopt_seed_traj,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds,
|
||||
newton_iters=trajopt_newton_iters,
|
||||
return_all_solutions=plan_config.parallel_finetune,
|
||||
)
|
||||
if False and not traj_result.success.item():
|
||||
# pose_convergence = traj_result.position_error < self.
|
||||
print(
|
||||
traj_result.position_error.item(),
|
||||
traj_result.rotation_error.item(),
|
||||
torch.count_nonzero(~traj_result.metrics.feasible[0]).item(),
|
||||
torch.count_nonzero(~traj_result.metrics.feasible[1]).item(),
|
||||
traj_result.optimized_dt.item(),
|
||||
)
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
self.trajopt_solver.interpolation_type = og_value
|
||||
# self.trajopt_solver.compute_metrics(not og_evaluate, og_evaluate)
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["trajopt_result"] = traj_result
|
||||
# run finetune
|
||||
if plan_config.enable_finetune_trajopt and traj_result.success[0].item():
|
||||
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
|
||||
with profiler.record_function("motion_gen/finetune_trajopt"):
|
||||
seed_traj = traj_result.solution.position.clone()
|
||||
seed_traj = torch.roll(seed_traj, -2, dims=-2)
|
||||
# seed_traj[..., -2:, :] = seed_traj[..., -3, :]
|
||||
seed_traj = traj_result.raw_action.clone() # solution.position.clone()
|
||||
seed_traj = seed_traj.contiguous()
|
||||
og_solve_time = traj_result.solve_time
|
||||
seed_override = 1
|
||||
opt_dt = traj_result.optimized_dt
|
||||
|
||||
if plan_config.parallel_finetune:
|
||||
opt_dt = torch.min(opt_dt[traj_result.success])
|
||||
seed_override = solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds
|
||||
scaled_dt = torch.clamp(
|
||||
traj_result.optimized_dt * self.finetune_dt_scale,
|
||||
opt_dt * self.finetune_dt_scale,
|
||||
self.trajopt_solver.interpolation_dt,
|
||||
)
|
||||
self.finetune_trajopt_solver.update_solver_dt(scaled_dt.item())
|
||||
@@ -1545,26 +1556,16 @@ class MotionGen(MotionGenConfig):
|
||||
solve_state,
|
||||
seed_traj,
|
||||
trajopt_instance=self.finetune_trajopt_solver,
|
||||
num_seeds_override=1,
|
||||
num_seeds_override=seed_override,
|
||||
)
|
||||
if False and not traj_result.success.item():
|
||||
print(
|
||||
traj_result.position_error.item(),
|
||||
traj_result.rotation_error.item(),
|
||||
torch.count_nonzero(~traj_result.metrics.feasible).item(),
|
||||
traj_result.optimized_dt.item(),
|
||||
)
|
||||
# if not traj_result.success.item():
|
||||
# #print(traj_result.metrics.constraint)
|
||||
# print(traj_result.position_error.item() * 100.0,
|
||||
# traj_result.rotation_error.item() * 100.0)
|
||||
|
||||
result.finetune_time = traj_result.solve_time
|
||||
|
||||
traj_result.solve_time = og_solve_time
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["finetune_trajopt_result"] = traj_result
|
||||
|
||||
elif plan_config.enable_finetune_trajopt:
|
||||
traj_result.success = traj_result.success[0:1]
|
||||
result.solve_time += traj_result.solve_time + result.finetune_time
|
||||
result.trajopt_time = traj_result.solve_time
|
||||
result.trajopt_attempts = 1
|
||||
@@ -1576,12 +1577,220 @@ class MotionGen(MotionGenConfig):
|
||||
result.interpolated_plan = traj_result.interpolated_solution.trim_trajectory(
|
||||
0, traj_result.path_buffer_last_tstep[0]
|
||||
)
|
||||
# print(ik_result.position_error[ik_result.success] * 1000.0)
|
||||
# print(traj_result.position_error * 1000.0)
|
||||
# exit()
|
||||
result.interpolation_dt = self.trajopt_solver.interpolation_dt
|
||||
result.path_buffer_last_tstep = traj_result.path_buffer_last_tstep
|
||||
result.position_error = traj_result.position_error
|
||||
result.rotation_error = traj_result.rotation_error
|
||||
result.optimized_dt = traj_result.optimized_dt
|
||||
result.optimized_plan = traj_result.solution
|
||||
return result
|
||||
|
||||
def _plan_js_from_solve_state(
|
||||
self,
|
||||
solve_state: ReacherSolveState,
|
||||
start_state: JointState,
|
||||
goal_state: JointState,
|
||||
plan_config: MotionGenPlanConfig = MotionGenPlanConfig(),
|
||||
) -> MotionGenResult:
|
||||
trajopt_seed_traj = None
|
||||
trajopt_seed_success = None
|
||||
trajopt_newton_iters = None
|
||||
|
||||
graph_success = 0
|
||||
if len(start_state.shape) == 1:
|
||||
log_error("Joint state should be not a vector (dof) should be (bxdof)")
|
||||
|
||||
result = MotionGenResult(cspace_error=torch.zeros((1), device=self.tensor_args.device))
|
||||
# do graph search:
|
||||
if plan_config.enable_graph:
|
||||
start_config = torch.zeros(
|
||||
(solve_state.num_graph_seeds, self.js_trajopt_solver.dof),
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
)
|
||||
goal_config = start_config.clone()
|
||||
start_config[:] = start_state.position
|
||||
goal_config[:] = goal_state.position
|
||||
interpolation_steps = None
|
||||
if plan_config.enable_opt:
|
||||
interpolation_steps = self.js_trajopt_solver.action_horizon
|
||||
log_info("MG: running GP")
|
||||
graph_result = self.graph_search(start_config, goal_config, interpolation_steps)
|
||||
trajopt_seed_success = graph_result.success
|
||||
|
||||
graph_success = torch.count_nonzero(graph_result.success).item()
|
||||
result.graph_time = graph_result.solve_time
|
||||
result.solve_time += graph_result.solve_time
|
||||
if graph_success > 0:
|
||||
result.graph_plan = graph_result.interpolated_plan
|
||||
result.interpolated_plan = graph_result.interpolated_plan
|
||||
|
||||
result.used_graph = True
|
||||
if plan_config.enable_opt:
|
||||
trajopt_seed = (
|
||||
result.graph_plan.position.view(
|
||||
1, # solve_state.batch_size,
|
||||
graph_success, # solve_state.num_trajopt_seeds,
|
||||
interpolation_steps,
|
||||
self._dof,
|
||||
)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
)
|
||||
trajopt_seed_traj = torch.zeros(
|
||||
(trajopt_seed.shape[0], 1, self.trajopt_solver.action_horizon, self._dof),
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
)
|
||||
trajopt_seed_traj[:, :, :interpolation_steps, :] = trajopt_seed
|
||||
trajopt_seed_success = graph_result.success
|
||||
|
||||
trajopt_seed_success = trajopt_seed_success.view(
|
||||
1, solve_state.num_trajopt_seeds
|
||||
)
|
||||
trajopt_newton_iters = self.graph_trajopt_iters
|
||||
else:
|
||||
_, idx = torch.topk(
|
||||
graph_result.path_length[graph_result.success], k=1, largest=False
|
||||
)
|
||||
result.interpolated_plan = result.interpolated_plan[idx].squeeze(0)
|
||||
result.optimized_dt = self.tensor_args.to_device(self.interpolation_dt)
|
||||
result.optimized_plan = result.interpolated_plan[
|
||||
: graph_result.path_buffer_last_tstep[idx.item()]
|
||||
]
|
||||
idx = idx.view(-1) + self._batch_col
|
||||
result.cspace_error = torch.zeros((1), device=self.tensor_args.device)
|
||||
|
||||
result.path_buffer_last_tstep = graph_result.path_buffer_last_tstep[
|
||||
idx.item() : idx.item() + 1
|
||||
]
|
||||
result.success = torch.as_tensor([True], device=self.tensor_args.device)
|
||||
return result
|
||||
else:
|
||||
result.success = torch.as_tensor([False], device=self.tensor_args.device)
|
||||
result.status = "Graph Fail"
|
||||
if not graph_result.valid_query:
|
||||
result.valid_query = False
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["graph_debug"] = graph_result.debug_info
|
||||
return result
|
||||
if plan_config.need_graph_success:
|
||||
return result
|
||||
|
||||
# do trajopt:
|
||||
if plan_config.enable_opt:
|
||||
with profiler.record_function("motion_gen/setup_trajopt_seeds"):
|
||||
# self._trajopt_goal_config[:, :ik_success] = goal_config
|
||||
|
||||
goal = Goal(
|
||||
current_state=start_state,
|
||||
goal_state=goal_state,
|
||||
)
|
||||
|
||||
if trajopt_seed_traj is None or graph_success < solve_state.num_trajopt_seeds * 1:
|
||||
seed_goal = Goal(
|
||||
current_state=start_state.repeat_seeds(solve_state.num_trajopt_seeds),
|
||||
goal_state=goal_state.repeat_seeds(solve_state.num_trajopt_seeds),
|
||||
)
|
||||
if trajopt_seed_traj is not None:
|
||||
trajopt_seed_traj = trajopt_seed_traj.transpose(0, 1).contiguous()
|
||||
# batch, num_seeds, h, dof
|
||||
if trajopt_seed_success.shape[1] < self.js_trajopt_solver.num_seeds:
|
||||
trajopt_seed_success_new = torch.zeros(
|
||||
(1, solve_state.num_trajopt_seeds),
|
||||
device=self.tensor_args.device,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
trajopt_seed_success_new[
|
||||
0, : trajopt_seed_success.shape[1]
|
||||
] = trajopt_seed_success
|
||||
trajopt_seed_success = trajopt_seed_success_new
|
||||
# create seeds here:
|
||||
trajopt_seed_traj = self.js_trajopt_solver.get_seed_set(
|
||||
seed_goal,
|
||||
trajopt_seed_traj, # batch, num_seeds, h, dof
|
||||
num_seeds=1,
|
||||
batch_mode=False,
|
||||
seed_success=trajopt_seed_success,
|
||||
)
|
||||
trajopt_seed_traj = trajopt_seed_traj.view(
|
||||
self.js_trajopt_solver.num_seeds * 1,
|
||||
1,
|
||||
self.trajopt_solver.action_horizon,
|
||||
self._dof,
|
||||
).contiguous()
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
og_value = self.trajopt_solver.interpolation_type
|
||||
self.js_trajopt_solver.interpolation_type = InterpolateType.LINEAR_CUDA
|
||||
with profiler.record_function("motion_gen/trajopt"):
|
||||
log_info("MG: running TO")
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal,
|
||||
solve_state,
|
||||
trajopt_seed_traj,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds * 1,
|
||||
newton_iters=trajopt_newton_iters,
|
||||
return_all_solutions=plan_config.enable_finetune_trajopt,
|
||||
trajopt_instance=self.js_trajopt_solver,
|
||||
)
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
self.trajopt_solver.interpolation_type = og_value
|
||||
# self.trajopt_solver.compute_metrics(not og_evaluate, og_evaluate)
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["trajopt_result"] = traj_result
|
||||
if torch.count_nonzero(traj_result.success) == 0:
|
||||
result.status = "TrajOpt Fail"
|
||||
# run finetune
|
||||
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
|
||||
with profiler.record_function("motion_gen/finetune_trajopt"):
|
||||
seed_traj = traj_result.raw_action.clone() # solution.position.clone()
|
||||
seed_traj = seed_traj.contiguous()
|
||||
og_solve_time = traj_result.solve_time
|
||||
|
||||
scaled_dt = torch.clamp(
|
||||
torch.max(traj_result.optimized_dt[traj_result.success]),
|
||||
self.trajopt_solver.interpolation_dt,
|
||||
)
|
||||
og_dt = self.js_trajopt_solver.solver_dt
|
||||
self.js_trajopt_solver.update_solver_dt(scaled_dt.item())
|
||||
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal,
|
||||
solve_state,
|
||||
seed_traj,
|
||||
trajopt_instance=self.js_trajopt_solver,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds,
|
||||
)
|
||||
self.js_trajopt_solver.update_solver_dt(og_dt)
|
||||
|
||||
result.finetune_time = traj_result.solve_time
|
||||
|
||||
traj_result.solve_time = og_solve_time
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["finetune_trajopt_result"] = traj_result
|
||||
if torch.count_nonzero(traj_result.success) == 0:
|
||||
result.status = "Finetune Fail"
|
||||
elif plan_config.enable_finetune_trajopt:
|
||||
traj_result.success = traj_result.success[0:1]
|
||||
result.solve_time += traj_result.solve_time + result.finetune_time
|
||||
result.trajopt_time = traj_result.solve_time
|
||||
result.trajopt_attempts = 1
|
||||
result.success = traj_result.success
|
||||
|
||||
result.interpolated_plan = traj_result.interpolated_solution.trim_trajectory(
|
||||
0, traj_result.path_buffer_last_tstep[0]
|
||||
)
|
||||
# print(ik_result.position_error[ik_result.success] * 1000.0)
|
||||
# print(traj_result.position_error * 1000.0)
|
||||
# exit()
|
||||
result.interpolation_dt = self.trajopt_solver.interpolation_dt
|
||||
result.path_buffer_last_tstep = traj_result.path_buffer_last_tstep
|
||||
result.cspace_error = traj_result.cspace_error
|
||||
result.optimized_dt = traj_result.optimized_dt
|
||||
result.optimized_plan = traj_result.solution
|
||||
|
||||
return result
|
||||
|
||||
@@ -1644,7 +1853,7 @@ class MotionGen(MotionGenConfig):
|
||||
if plan_config.enable_graph:
|
||||
interpolation_steps = None
|
||||
if plan_config.enable_opt:
|
||||
interpolation_steps = self.trajopt_solver.traj_tsteps - 4
|
||||
interpolation_steps = self.trajopt_solver.action_horizon
|
||||
|
||||
start_graph_state = start_state.repeat_seeds(ik_out_seeds)
|
||||
start_config = start_graph_state.position[ik_result.success.view(-1)].view(
|
||||
@@ -1662,23 +1871,17 @@ class MotionGen(MotionGenConfig):
|
||||
result.used_graph = True
|
||||
|
||||
if plan_config.enable_opt:
|
||||
trajopt_seed = (
|
||||
result.graph_plan.position.view(
|
||||
1, # solve_state.batch_size,
|
||||
graph_success, # solve_state.num_trajopt_seeds,
|
||||
interpolation_steps,
|
||||
self._dof,
|
||||
)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
)
|
||||
trajopt_seed = result.graph_plan.position.view(
|
||||
graph_success, # solve_state.num_trajopt_seeds,
|
||||
interpolation_steps,
|
||||
self._dof,
|
||||
).contiguous()
|
||||
trajopt_seed_traj = torch.zeros(
|
||||
(trajopt_seed.shape[0], 1, self.trajopt_solver.traj_tsteps, self._dof),
|
||||
(1, trajopt_seed.shape[0], self.trajopt_solver.action_horizon, self._dof),
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
)
|
||||
trajopt_seed_traj[:, :, :-4, :] = trajopt_seed
|
||||
trajopt_seed_traj[:, :, -4:, :] = trajopt_seed_traj[:, :, -5:-4, :]
|
||||
trajopt_seed_traj[0, :, :interpolation_steps, :] = trajopt_seed
|
||||
|
||||
trajopt_seed_success = ik_result.success.clone()
|
||||
trajopt_seed_success[ik_result.success] = graph_result.success
|
||||
@@ -1766,14 +1969,54 @@ class MotionGen(MotionGenConfig):
|
||||
trajopt_seed_traj = trajopt_seed_traj.view(
|
||||
solve_state.num_trajopt_seeds,
|
||||
solve_state.batch_size,
|
||||
self.trajopt_solver.traj_tsteps,
|
||||
self.trajopt_solver.action_horizon,
|
||||
self._dof,
|
||||
).contiguous()
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
og_value = self.trajopt_solver.interpolation_type
|
||||
self.trajopt_solver.interpolation_type = InterpolateType.LINEAR_CUDA
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal, solve_state, trajopt_seed_traj, newton_iters=trajopt_newton_iters
|
||||
goal,
|
||||
solve_state,
|
||||
trajopt_seed_traj,
|
||||
newton_iters=trajopt_newton_iters,
|
||||
return_all_solutions=True,
|
||||
)
|
||||
|
||||
# output of traj result will have 1 solution per batch
|
||||
|
||||
# run finetune opt on 1 solution per batch:
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
self.trajopt_solver.interpolation_type = og_value
|
||||
# self.trajopt_solver.compute_metrics(not og_evaluate, og_evaluate)
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["trajopt_result"] = traj_result
|
||||
# run finetune
|
||||
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
|
||||
with profiler.record_function("motion_gen/finetune_trajopt"):
|
||||
seed_traj = traj_result.raw_action.clone() # solution.position.clone()
|
||||
seed_traj = seed_traj.contiguous()
|
||||
og_solve_time = traj_result.solve_time
|
||||
|
||||
scaled_dt = torch.clamp(
|
||||
torch.max(traj_result.optimized_dt[traj_result.success])
|
||||
* self.finetune_dt_scale,
|
||||
self.trajopt_solver.interpolation_dt,
|
||||
)
|
||||
self.finetune_trajopt_solver.update_solver_dt(scaled_dt.item())
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal,
|
||||
solve_state,
|
||||
seed_traj,
|
||||
trajopt_instance=self.finetune_trajopt_solver,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds,
|
||||
)
|
||||
|
||||
result.finetune_time = traj_result.solve_time
|
||||
|
||||
traj_result.solve_time = og_solve_time
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["finetune_trajopt_result"] = traj_result
|
||||
result.success = traj_result.success
|
||||
|
||||
result.interpolated_plan = traj_result.interpolated_solution
|
||||
@@ -1837,6 +2080,7 @@ class MotionGen(MotionGenConfig):
|
||||
batch: Optional[int] = None,
|
||||
warmup_js_trajopt: bool = True,
|
||||
batch_env_mode: bool = False,
|
||||
parallel_finetune: bool = False,
|
||||
):
|
||||
log_info("Warmup")
|
||||
|
||||
@@ -1854,7 +2098,11 @@ class MotionGen(MotionGenConfig):
|
||||
self.plan_single(
|
||||
start_state,
|
||||
retract_pose,
|
||||
MotionGenPlanConfig(max_attempts=1, enable_finetune_trajopt=True),
|
||||
MotionGenPlanConfig(
|
||||
max_attempts=1,
|
||||
enable_finetune_trajopt=True,
|
||||
parallel_finetune=parallel_finetune,
|
||||
),
|
||||
link_poses=link_poses,
|
||||
)
|
||||
if enable_graph:
|
||||
@@ -1867,7 +2115,10 @@ class MotionGen(MotionGenConfig):
|
||||
start_state,
|
||||
retract_pose,
|
||||
MotionGenPlanConfig(
|
||||
max_attempts=1, enable_finetune_trajopt=True, enable_graph=enable_graph
|
||||
max_attempts=1,
|
||||
enable_finetune_trajopt=True,
|
||||
enable_graph=enable_graph,
|
||||
parallel_finetune=parallel_finetune,
|
||||
),
|
||||
link_poses=link_poses,
|
||||
)
|
||||
@@ -1925,14 +2176,24 @@ class MotionGen(MotionGenConfig):
|
||||
}
|
||||
result = None
|
||||
goal = Goal(goal_state=goal_state, current_state=start_state)
|
||||
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.SINGLE,
|
||||
num_ik_seeds=1,
|
||||
num_trajopt_seeds=self.js_trajopt_solver.num_seeds,
|
||||
num_graph_seeds=self.js_trajopt_solver.num_seeds,
|
||||
batch_size=1,
|
||||
n_envs=1,
|
||||
n_goalset=1,
|
||||
)
|
||||
for n in range(plan_config.max_attempts):
|
||||
traj_result = self.js_trajopt_solver.solve_single(goal)
|
||||
traj_result = self._plan_js_from_solve_state(
|
||||
solve_state, start_state, goal_state, plan_config=plan_config
|
||||
)
|
||||
time_dict["trajopt_time"] += traj_result.solve_time
|
||||
time_dict["trajopt_attempts"] = n
|
||||
|
||||
if result is None:
|
||||
result = MotionGenResult(success=traj_result.success)
|
||||
result = traj_result
|
||||
|
||||
if traj_result.success.item():
|
||||
break
|
||||
@@ -1940,25 +2201,7 @@ class MotionGen(MotionGenConfig):
|
||||
result.solve_time = time_dict["trajopt_time"]
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info = {"trajopt_result": traj_result}
|
||||
status = None
|
||||
if not traj_result.success.item():
|
||||
# print(traj_result.cspace_error, traj_result.success)
|
||||
status = ""
|
||||
if traj_result.cspace_error.item() >= self.js_trajopt_solver.cspace_threshold:
|
||||
status += " Fail: C-SPACE Convergence"
|
||||
if torch.count_nonzero(~traj_result.metrics.feasible).item() > 0:
|
||||
status += " Fail: Constraints"
|
||||
# print(traj_result.metrics.feasible)
|
||||
|
||||
result.status = status
|
||||
result.position_error = traj_result.position_error
|
||||
result.rotation_error = traj_result.rotation_error
|
||||
result.cspace_error = traj_result.cspace_error
|
||||
result.optimized_dt = traj_result.optimized_dt
|
||||
result.interpolated_plan = traj_result.interpolated_solution
|
||||
result.optimized_plan = traj_result.solution
|
||||
result.path_buffer_last_tstep = traj_result.path_buffer_last_tstep
|
||||
result.success = traj_result.success
|
||||
return result
|
||||
|
||||
def plan(
|
||||
|
||||
@@ -77,6 +77,7 @@ class TrajOptSolverConfig:
|
||||
use_cuda_graph_metrics: bool = False
|
||||
trim_steps: Optional[List[int]] = None
|
||||
store_debug_in_result: bool = False
|
||||
optimize_dt: bool = True
|
||||
|
||||
@staticmethod
|
||||
@profiler.record_function("trajopt_config/load_from_robot_config")
|
||||
@@ -107,7 +108,7 @@ class TrajOptSolverConfig:
|
||||
collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.PRIMITIVE,
|
||||
traj_evaluator_config: TrajEvaluatorConfig = TrajEvaluatorConfig(),
|
||||
traj_evaluator: Optional[TrajEvaluator] = None,
|
||||
minimize_jerk: Optional[bool] = None,
|
||||
minimize_jerk: bool = True,
|
||||
use_gradient_descent: bool = False,
|
||||
collision_cache: Optional[Dict[str, int]] = None,
|
||||
n_collision_envs: Optional[int] = None,
|
||||
@@ -126,7 +127,9 @@ class TrajOptSolverConfig:
|
||||
smooth_weight: Optional[List[float]] = None,
|
||||
state_finite_difference_mode: Optional[str] = None,
|
||||
filter_robot_command: bool = False,
|
||||
optimize_dt: bool = True,
|
||||
):
|
||||
# NOTE: Don't have default optimize_dt, instead read from a configuration file.
|
||||
# use default values, disable environment collision checking
|
||||
if isinstance(robot_cfg, str):
|
||||
robot_cfg = load_yaml(join_path(get_robot_configs_path(), robot_cfg))["robot_cfg"]
|
||||
@@ -199,6 +202,7 @@ class TrajOptSolverConfig:
|
||||
fixed_iters = True
|
||||
grad_config_data["lbfgs"]["store_debug"] = store_debug
|
||||
config_data["mppi"]["store_debug"] = store_debug
|
||||
store_debug_in_result = True
|
||||
|
||||
if use_cuda_graph is not None:
|
||||
config_data["mppi"]["use_cuda_graph"] = use_cuda_graph
|
||||
@@ -332,6 +336,7 @@ class TrajOptSolverConfig:
|
||||
use_cuda_graph_metrics=use_cuda_graph,
|
||||
trim_steps=trim_steps,
|
||||
store_debug_in_result=store_debug_in_result,
|
||||
optimize_dt=optimize_dt,
|
||||
)
|
||||
return trajopt_cfg
|
||||
|
||||
@@ -354,6 +359,7 @@ class TrajResult(Sequence):
|
||||
smooth_label: Optional[T_BValue_bool] = None
|
||||
optimized_dt: Optional[torch.Tensor] = None
|
||||
raw_solution: Optional[JointState] = None
|
||||
raw_action: Optional[torch.Tensor] = None
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# position_error = rotation_error = cspace_error = path_buffer_last_tstep = None
|
||||
@@ -392,17 +398,14 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
def __init__(self, config: TrajOptSolverConfig) -> None:
|
||||
super().__init__(**vars(config))
|
||||
self.dof = self.rollout_fn.d_action
|
||||
self.delta_vec = action_interpolate_kernel(2, self.traj_tsteps, self.tensor_args).unsqueeze(
|
||||
0
|
||||
)
|
||||
self.waypoint_delta_vec = interpolate_kernel(
|
||||
3, int(self.traj_tsteps / 2), self.tensor_args
|
||||
).unsqueeze(0)
|
||||
self.waypoint_delta_vec = torch.roll(self.waypoint_delta_vec, -1, dims=1)
|
||||
self.waypoint_delta_vec[:, -1, :] = self.waypoint_delta_vec[:, -2, :]
|
||||
assert self.traj_tsteps / 2 != 0.0
|
||||
self.solver.update_nenvs(self.num_seeds)
|
||||
self.action_horizon = self.rollout_fn.action_horizon
|
||||
self.delta_vec = interpolate_kernel(2, self.action_horizon, self.tensor_args).unsqueeze(0)
|
||||
|
||||
self.waypoint_delta_vec = interpolate_kernel(
|
||||
3, int(self.action_horizon / 2), self.tensor_args
|
||||
).unsqueeze(0)
|
||||
assert self.action_horizon / 2 != 0.0
|
||||
self.solver.update_nenvs(self.num_seeds)
|
||||
self._max_joint_vel = (
|
||||
self.solver.safety_rollout.state_bounds.velocity.view(2, self.dof)[1, :].reshape(
|
||||
1, 1, self.dof
|
||||
@@ -410,7 +413,6 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
) - 0.02
|
||||
self._max_joint_acc = self.rollout_fn.state_bounds.acceleration[1, :] - 0.02
|
||||
self._max_joint_jerk = self.rollout_fn.state_bounds.jerk[1, :] - 0.02
|
||||
# self._max_joint_jerk = self._max_joint_jerk * 0.0 + 10.0
|
||||
self._num_seeds = -1
|
||||
self._col = None
|
||||
if self.traj_evaluator is None:
|
||||
@@ -844,9 +846,9 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
result.metrics = self.interpolate_rollout.get_metrics(interpolated_trajs)
|
||||
|
||||
st_time = time.time()
|
||||
|
||||
feasible = torch.all(result.metrics.feasible, dim=-1)
|
||||
|
||||
# if self.num_seeds == 1:
|
||||
# print(result.metrics)
|
||||
if result.metrics.position_error is not None:
|
||||
converge = torch.logical_and(
|
||||
result.metrics.position_error[..., -1] <= self.position_threshold,
|
||||
@@ -874,6 +876,7 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
cspace_error=result.metrics.cspace_error,
|
||||
optimized_dt=opt_dt,
|
||||
raw_solution=result.action,
|
||||
raw_action=result.raw_action,
|
||||
)
|
||||
else:
|
||||
# get path length:
|
||||
@@ -904,7 +907,6 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
convergence_error = result.metrics.cspace_error[..., -1]
|
||||
else:
|
||||
raise ValueError("convergence check requires either goal_pose or goal_state")
|
||||
|
||||
error = convergence_error + smooth_cost
|
||||
error[~success] += 10000.0
|
||||
if batch_mode:
|
||||
@@ -919,6 +921,7 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
success = success[idx : idx + 1]
|
||||
|
||||
best_act_seq = result.action[idx]
|
||||
best_raw_action = result.raw_action[idx]
|
||||
interpolated_traj = interpolated_trajs[idx]
|
||||
position_error = rotation_error = cspace_error = None
|
||||
if result.metrics.position_error is not None:
|
||||
@@ -961,6 +964,7 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
smooth_label=smooth_label,
|
||||
optimized_dt=opt_dt,
|
||||
raw_solution=best_act_seq,
|
||||
raw_action=best_raw_action,
|
||||
)
|
||||
return traj_result
|
||||
|
||||
@@ -1044,7 +1048,7 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
if goal.goal_state is not None and self.use_cspace_seed:
|
||||
# get linear seed
|
||||
seed_traj = self.get_seeds(goal.current_state, goal.goal_state, num_seeds=num_seeds)
|
||||
# .view(batch_size, self.num_seeds, self.traj_tsteps, self.dof)
|
||||
# .view(batch_size, self.num_seeds, self.action_horizon, self.dof)
|
||||
else:
|
||||
# get start repeat seed:
|
||||
log_info("No goal state found, using current config to seed")
|
||||
@@ -1063,7 +1067,7 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
seed_traj = torch.cat((seed_traj, new_seeds), dim=0) # n_seed, batch, h, dof
|
||||
|
||||
seed_traj = seed_traj.view(
|
||||
total_seeds, self.traj_tsteps, self.dof
|
||||
total_seeds, self.action_horizon, self.dof
|
||||
) # n_seeds,batch, h, dof
|
||||
return seed_traj
|
||||
|
||||
@@ -1079,27 +1083,27 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
if n_seeds["linear"] > 0:
|
||||
linear_seed = self.get_linear_seed(start_state, goal_state)
|
||||
|
||||
linear_seeds = linear_seed.view(1, -1, self.traj_tsteps, self.dof).repeat(
|
||||
linear_seeds = linear_seed.view(1, -1, self.action_horizon, self.dof).repeat(
|
||||
1, n_seeds["linear"], 1, 1
|
||||
)
|
||||
seed_set.append(linear_seeds)
|
||||
if n_seeds["bias"] > 0:
|
||||
bias_seed = self.get_bias_seed(start_state, goal_state)
|
||||
bias_seeds = bias_seed.view(1, -1, self.traj_tsteps, self.dof).repeat(
|
||||
bias_seeds = bias_seed.view(1, -1, self.action_horizon, self.dof).repeat(
|
||||
1, n_seeds["bias"], 1, 1
|
||||
)
|
||||
seed_set.append(bias_seeds)
|
||||
if n_seeds["start"] > 0:
|
||||
bias_seed = self.get_start_seed(start_state)
|
||||
|
||||
bias_seeds = bias_seed.view(1, -1, self.traj_tsteps, self.dof).repeat(
|
||||
bias_seeds = bias_seed.view(1, -1, self.action_horizon, self.dof).repeat(
|
||||
1, n_seeds["start"], 1, 1
|
||||
)
|
||||
seed_set.append(bias_seeds)
|
||||
if n_seeds["goal"] > 0:
|
||||
bias_seed = self.get_start_seed(goal_state)
|
||||
|
||||
bias_seeds = bias_seed.view(1, -1, self.traj_tsteps, self.dof).repeat(
|
||||
bias_seeds = bias_seed.view(1, -1, self.action_horizon, self.dof).repeat(
|
||||
1, n_seeds["goal"], 1, 1
|
||||
)
|
||||
seed_set.append(bias_seeds)
|
||||
@@ -1142,6 +1146,7 @@ class TrajOptSolver(TrajOptSolverConfig):
|
||||
tensor_args=self.tensor_args,
|
||||
out_traj_state=self._interpolated_traj_buffer,
|
||||
min_dt=self.traj_evaluator_config.min_dt,
|
||||
optimize_dt=self.optimize_dt,
|
||||
)
|
||||
|
||||
return state, last_tstep, opt_dt
|
||||
|
||||
@@ -63,6 +63,22 @@ class ReacherSolveState:
|
||||
if self.num_seeds is None:
|
||||
self.num_seeds = self.num_mpc_seeds
|
||||
|
||||
def clone(self):
|
||||
return ReacherSolveState(
|
||||
solve_type=self.solve_type,
|
||||
n_envs=self.n_envs,
|
||||
batch_size=self.batch_size,
|
||||
n_goalset=self.n_goalset,
|
||||
batch_env=self.batch_env,
|
||||
batch_retract=self.batch_retract,
|
||||
batch_mode=self.batch_mode,
|
||||
num_seeds=self.num_seeds,
|
||||
num_ik_seeds=self.num_ik_seeds,
|
||||
num_graph_seeds=self.num_graph_seeds,
|
||||
num_trajopt_seeds=self.num_trajopt_seeds,
|
||||
num_mpc_seeds=self.num_mpc_seeds,
|
||||
)
|
||||
|
||||
def get_batch_size(self):
|
||||
return self.num_seeds * self.batch_size
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ class WrapResult:
|
||||
metrics: Optional[RolloutMetrics] = None
|
||||
debug: Any = None
|
||||
js_action: Optional[State] = None
|
||||
raw_action: Optional[torch.Tensor] = None
|
||||
|
||||
def clone(self):
|
||||
return WrapResult(
|
||||
@@ -155,6 +156,7 @@ class WrapBase(WrapConfig):
|
||||
solve_time=self.opt_dt,
|
||||
metrics=metrics,
|
||||
debug={"steps": self.get_debug_data(), "cost": self.get_debug_cost()},
|
||||
raw_action=act_seq,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user