update to 0.6.2
This commit is contained in:
@@ -122,7 +122,7 @@ class MotionGenConfig:
|
||||
interpolation_dt: float = 0.01
|
||||
|
||||
#: scale initial dt by this value to finetune trajectory optimization.
|
||||
finetune_dt_scale: float = 1.05
|
||||
finetune_dt_scale: float = 0.98
|
||||
|
||||
@staticmethod
|
||||
@profiler.record_function("motion_gen_config/load_from_robot_config")
|
||||
@@ -192,12 +192,13 @@ class MotionGenConfig:
|
||||
smooth_weight: List[float] = None,
|
||||
finetune_smooth_weight: Optional[List[float]] = None,
|
||||
state_finite_difference_mode: Optional[str] = None,
|
||||
finetune_dt_scale: float = 1.05,
|
||||
finetune_dt_scale: float = 0.98,
|
||||
maximum_trajectory_time: Optional[float] = None,
|
||||
maximum_trajectory_dt: float = 0.1,
|
||||
velocity_scale: Optional[Union[List[float], float]] = None,
|
||||
acceleration_scale: Optional[Union[List[float], float]] = None,
|
||||
jerk_scale: Optional[Union[List[float], float]] = None,
|
||||
optimize_dt: bool = True,
|
||||
):
|
||||
"""Load motion generation configuration from robot and world configurations.
|
||||
|
||||
@@ -279,7 +280,11 @@ class MotionGenConfig:
|
||||
"""
|
||||
|
||||
init_warp(tensor_args=tensor_args)
|
||||
|
||||
if js_trajopt_tsteps is not None:
|
||||
log_warn("js_trajopt_tsteps is deprecated, use trajopt_tsteps instead.")
|
||||
trajopt_tsteps = js_trajopt_tsteps
|
||||
if trajopt_tsteps is not None:
|
||||
js_trajopt_tsteps = trajopt_tsteps
|
||||
if velocity_scale is not None and isinstance(velocity_scale, float):
|
||||
velocity_scale = [velocity_scale]
|
||||
|
||||
@@ -318,6 +323,8 @@ class MotionGenConfig:
|
||||
|
||||
if isinstance(robot_cfg, str):
|
||||
robot_cfg = load_yaml(join_path(get_robot_configs_path(), robot_cfg))["robot_cfg"]
|
||||
elif isinstance(robot_cfg, Dict) and "robot_cfg" in robot_cfg.keys():
|
||||
robot_cfg = robot_cfg["robot_cfg"]
|
||||
if isinstance(robot_cfg, RobotConfig):
|
||||
if ee_link_name is not None:
|
||||
log_error("ee link cannot be changed after creating RobotConfig")
|
||||
@@ -445,6 +452,7 @@ class MotionGenConfig:
|
||||
state_finite_difference_mode=state_finite_difference_mode,
|
||||
filter_robot_command=filter_robot_command,
|
||||
minimize_jerk=minimize_jerk,
|
||||
optimize_dt=optimize_dt,
|
||||
)
|
||||
trajopt_solver = TrajOptSolver(trajopt_cfg)
|
||||
|
||||
@@ -465,8 +473,6 @@ class MotionGenConfig:
|
||||
self_collision_check=self_collision_check,
|
||||
self_collision_opt=self_collision_opt,
|
||||
grad_trajopt_iters=grad_trajopt_iters,
|
||||
# num_seeds=num_trajopt_noisy_seeds,
|
||||
# seed_ratio=trajopt_seed_ratio,
|
||||
interpolation_dt=interpolation_dt,
|
||||
use_particle_opt=trajopt_particle_opt,
|
||||
traj_evaluator_config=traj_evaluator_config,
|
||||
@@ -486,6 +492,7 @@ class MotionGenConfig:
|
||||
state_finite_difference_mode=state_finite_difference_mode,
|
||||
minimize_jerk=minimize_jerk,
|
||||
filter_robot_command=filter_robot_command,
|
||||
optimize_dt=optimize_dt,
|
||||
)
|
||||
js_trajopt_solver = TrajOptSolver(js_trajopt_cfg)
|
||||
|
||||
@@ -523,6 +530,7 @@ class MotionGenConfig:
|
||||
trim_steps=trim_steps,
|
||||
use_gradient_descent=use_gradient_descent,
|
||||
filter_robot_command=filter_robot_command,
|
||||
optimize_dt=optimize_dt,
|
||||
)
|
||||
|
||||
finetune_trajopt_solver = TrajOptSolver(finetune_trajopt_cfg)
|
||||
@@ -748,7 +756,9 @@ class MotionGenResult:
|
||||
|
||||
@property
|
||||
def motion_time(self):
|
||||
return self.optimized_dt * (self.optimized_plan.position.shape[-2] - 1)
|
||||
# -2 as last three timesteps have the same value
|
||||
# 0, 1 also have the same position value.
|
||||
return self.optimized_dt * (self.optimized_plan.position.shape[-2] - 1 - 2 - 1)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -762,7 +772,7 @@ class MotionGenPlanConfig:
|
||||
enable_graph_attempt: Optional[int] = 3
|
||||
disable_graph_attempt: Optional[int] = None
|
||||
ik_fail_return: Optional[int] = None
|
||||
partial_ik_opt: bool = True
|
||||
partial_ik_opt: bool = False
|
||||
num_ik_seeds: Optional[int] = None
|
||||
num_graph_seeds: Optional[int] = None
|
||||
num_trajopt_seeds: Optional[int] = None
|
||||
@@ -770,6 +780,7 @@ class MotionGenPlanConfig:
|
||||
fail_on_invalid_query: bool = True
|
||||
#: enables retiming trajectory optimization, useful for getting low jerk trajectories.
|
||||
enable_finetune_trajopt: bool = True
|
||||
parallel_finetune: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.enable_opt and not self.enable_graph:
|
||||
@@ -1103,8 +1114,11 @@ class MotionGen(MotionGenConfig):
|
||||
if plan_config.enable_graph:
|
||||
raise ValueError("Graph Search / Geometric Planner not supported in batch_env mode")
|
||||
|
||||
if plan_config.enable_graph:
|
||||
log_info("Batch mode enable graph is only supported with num_graph_seeds==1")
|
||||
if plan_config.enable_graph or (
|
||||
plan_config.enable_graph_attempt is not None
|
||||
and plan_config.max_attempts >= plan_config.enable_graph_attempt
|
||||
):
|
||||
log_warn("Batch mode enable graph is only supported with num_graph_seeds==1")
|
||||
plan_config.num_trajopt_seeds = 1
|
||||
plan_config.num_graph_seeds = 1
|
||||
solve_state.num_trajopt_seeds = 1
|
||||
@@ -1316,7 +1330,6 @@ class MotionGen(MotionGenConfig):
|
||||
trajopt_seed_traj = None
|
||||
trajopt_seed_success = None
|
||||
trajopt_newton_iters = None
|
||||
|
||||
graph_success = 0
|
||||
if len(start_state.shape) == 1:
|
||||
log_error("Joint state should be not a vector (dof) should be (bxdof)")
|
||||
@@ -1330,6 +1343,7 @@ class MotionGen(MotionGenConfig):
|
||||
plan_config.partial_ik_opt,
|
||||
link_poses,
|
||||
)
|
||||
|
||||
if not plan_config.enable_graph and plan_config.partial_ik_opt:
|
||||
ik_result.success[:] = True
|
||||
|
||||
@@ -1364,7 +1378,7 @@ class MotionGen(MotionGenConfig):
|
||||
if plan_config.enable_graph:
|
||||
interpolation_steps = None
|
||||
if plan_config.enable_opt:
|
||||
interpolation_steps = self.trajopt_solver.traj_tsteps - 4
|
||||
interpolation_steps = self.trajopt_solver.action_horizon
|
||||
log_info("MG: running GP")
|
||||
graph_result = self.graph_search(start_config, goal_config, interpolation_steps)
|
||||
trajopt_seed_success = graph_result.success
|
||||
@@ -1378,6 +1392,8 @@ class MotionGen(MotionGenConfig):
|
||||
|
||||
result.used_graph = True
|
||||
if plan_config.enable_opt:
|
||||
# print(result.graph_plan.position.shape, interpolation_steps,
|
||||
# graph_result.path_buffer_last_tstep)
|
||||
trajopt_seed = (
|
||||
result.graph_plan.position.view(
|
||||
1, # solve_state.batch_size,
|
||||
@@ -1389,12 +1405,11 @@ class MotionGen(MotionGenConfig):
|
||||
.contiguous()
|
||||
)
|
||||
trajopt_seed_traj = torch.zeros(
|
||||
(trajopt_seed.shape[0], 1, self.trajopt_solver.traj_tsteps, self._dof),
|
||||
(trajopt_seed.shape[0], 1, self.trajopt_solver.action_horizon, self._dof),
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
)
|
||||
trajopt_seed_traj[:, :, :-4, :] = trajopt_seed
|
||||
trajopt_seed_traj[:, :, -4:, :] = trajopt_seed_traj[:, :, -5:-4, :]
|
||||
trajopt_seed_traj[:, :, :interpolation_steps, :] = trajopt_seed
|
||||
trajopt_seed_success = ik_result.success.clone()
|
||||
trajopt_seed_success[ik_result.success] = graph_result.success
|
||||
|
||||
@@ -1497,7 +1512,7 @@ class MotionGen(MotionGenConfig):
|
||||
trajopt_seed_traj = trajopt_seed_traj.view(
|
||||
solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds,
|
||||
solve_state.batch_size,
|
||||
self.trajopt_solver.traj_tsteps,
|
||||
self.trajopt_solver.action_horizon,
|
||||
self._dof,
|
||||
).contiguous()
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
@@ -1511,31 +1526,27 @@ class MotionGen(MotionGenConfig):
|
||||
trajopt_seed_traj,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds,
|
||||
newton_iters=trajopt_newton_iters,
|
||||
return_all_solutions=plan_config.parallel_finetune,
|
||||
)
|
||||
if False and not traj_result.success.item():
|
||||
# pose_convergence = traj_result.position_error < self.
|
||||
print(
|
||||
traj_result.position_error.item(),
|
||||
traj_result.rotation_error.item(),
|
||||
torch.count_nonzero(~traj_result.metrics.feasible[0]).item(),
|
||||
torch.count_nonzero(~traj_result.metrics.feasible[1]).item(),
|
||||
traj_result.optimized_dt.item(),
|
||||
)
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
self.trajopt_solver.interpolation_type = og_value
|
||||
# self.trajopt_solver.compute_metrics(not og_evaluate, og_evaluate)
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["trajopt_result"] = traj_result
|
||||
# run finetune
|
||||
if plan_config.enable_finetune_trajopt and traj_result.success[0].item():
|
||||
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
|
||||
with profiler.record_function("motion_gen/finetune_trajopt"):
|
||||
seed_traj = traj_result.solution.position.clone()
|
||||
seed_traj = torch.roll(seed_traj, -2, dims=-2)
|
||||
# seed_traj[..., -2:, :] = seed_traj[..., -3, :]
|
||||
seed_traj = traj_result.raw_action.clone() # solution.position.clone()
|
||||
seed_traj = seed_traj.contiguous()
|
||||
og_solve_time = traj_result.solve_time
|
||||
seed_override = 1
|
||||
opt_dt = traj_result.optimized_dt
|
||||
|
||||
if plan_config.parallel_finetune:
|
||||
opt_dt = torch.min(opt_dt[traj_result.success])
|
||||
seed_override = solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds
|
||||
scaled_dt = torch.clamp(
|
||||
traj_result.optimized_dt * self.finetune_dt_scale,
|
||||
opt_dt * self.finetune_dt_scale,
|
||||
self.trajopt_solver.interpolation_dt,
|
||||
)
|
||||
self.finetune_trajopt_solver.update_solver_dt(scaled_dt.item())
|
||||
@@ -1545,26 +1556,16 @@ class MotionGen(MotionGenConfig):
|
||||
solve_state,
|
||||
seed_traj,
|
||||
trajopt_instance=self.finetune_trajopt_solver,
|
||||
num_seeds_override=1,
|
||||
num_seeds_override=seed_override,
|
||||
)
|
||||
if False and not traj_result.success.item():
|
||||
print(
|
||||
traj_result.position_error.item(),
|
||||
traj_result.rotation_error.item(),
|
||||
torch.count_nonzero(~traj_result.metrics.feasible).item(),
|
||||
traj_result.optimized_dt.item(),
|
||||
)
|
||||
# if not traj_result.success.item():
|
||||
# #print(traj_result.metrics.constraint)
|
||||
# print(traj_result.position_error.item() * 100.0,
|
||||
# traj_result.rotation_error.item() * 100.0)
|
||||
|
||||
result.finetune_time = traj_result.solve_time
|
||||
|
||||
traj_result.solve_time = og_solve_time
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["finetune_trajopt_result"] = traj_result
|
||||
|
||||
elif plan_config.enable_finetune_trajopt:
|
||||
traj_result.success = traj_result.success[0:1]
|
||||
result.solve_time += traj_result.solve_time + result.finetune_time
|
||||
result.trajopt_time = traj_result.solve_time
|
||||
result.trajopt_attempts = 1
|
||||
@@ -1576,12 +1577,220 @@ class MotionGen(MotionGenConfig):
|
||||
result.interpolated_plan = traj_result.interpolated_solution.trim_trajectory(
|
||||
0, traj_result.path_buffer_last_tstep[0]
|
||||
)
|
||||
# print(ik_result.position_error[ik_result.success] * 1000.0)
|
||||
# print(traj_result.position_error * 1000.0)
|
||||
# exit()
|
||||
result.interpolation_dt = self.trajopt_solver.interpolation_dt
|
||||
result.path_buffer_last_tstep = traj_result.path_buffer_last_tstep
|
||||
result.position_error = traj_result.position_error
|
||||
result.rotation_error = traj_result.rotation_error
|
||||
result.optimized_dt = traj_result.optimized_dt
|
||||
result.optimized_plan = traj_result.solution
|
||||
return result
|
||||
|
||||
def _plan_js_from_solve_state(
|
||||
self,
|
||||
solve_state: ReacherSolveState,
|
||||
start_state: JointState,
|
||||
goal_state: JointState,
|
||||
plan_config: MotionGenPlanConfig = MotionGenPlanConfig(),
|
||||
) -> MotionGenResult:
|
||||
trajopt_seed_traj = None
|
||||
trajopt_seed_success = None
|
||||
trajopt_newton_iters = None
|
||||
|
||||
graph_success = 0
|
||||
if len(start_state.shape) == 1:
|
||||
log_error("Joint state should be not a vector (dof) should be (bxdof)")
|
||||
|
||||
result = MotionGenResult(cspace_error=torch.zeros((1), device=self.tensor_args.device))
|
||||
# do graph search:
|
||||
if plan_config.enable_graph:
|
||||
start_config = torch.zeros(
|
||||
(solve_state.num_graph_seeds, self.js_trajopt_solver.dof),
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
)
|
||||
goal_config = start_config.clone()
|
||||
start_config[:] = start_state.position
|
||||
goal_config[:] = goal_state.position
|
||||
interpolation_steps = None
|
||||
if plan_config.enable_opt:
|
||||
interpolation_steps = self.js_trajopt_solver.action_horizon
|
||||
log_info("MG: running GP")
|
||||
graph_result = self.graph_search(start_config, goal_config, interpolation_steps)
|
||||
trajopt_seed_success = graph_result.success
|
||||
|
||||
graph_success = torch.count_nonzero(graph_result.success).item()
|
||||
result.graph_time = graph_result.solve_time
|
||||
result.solve_time += graph_result.solve_time
|
||||
if graph_success > 0:
|
||||
result.graph_plan = graph_result.interpolated_plan
|
||||
result.interpolated_plan = graph_result.interpolated_plan
|
||||
|
||||
result.used_graph = True
|
||||
if plan_config.enable_opt:
|
||||
trajopt_seed = (
|
||||
result.graph_plan.position.view(
|
||||
1, # solve_state.batch_size,
|
||||
graph_success, # solve_state.num_trajopt_seeds,
|
||||
interpolation_steps,
|
||||
self._dof,
|
||||
)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
)
|
||||
trajopt_seed_traj = torch.zeros(
|
||||
(trajopt_seed.shape[0], 1, self.trajopt_solver.action_horizon, self._dof),
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
)
|
||||
trajopt_seed_traj[:, :, :interpolation_steps, :] = trajopt_seed
|
||||
trajopt_seed_success = graph_result.success
|
||||
|
||||
trajopt_seed_success = trajopt_seed_success.view(
|
||||
1, solve_state.num_trajopt_seeds
|
||||
)
|
||||
trajopt_newton_iters = self.graph_trajopt_iters
|
||||
else:
|
||||
_, idx = torch.topk(
|
||||
graph_result.path_length[graph_result.success], k=1, largest=False
|
||||
)
|
||||
result.interpolated_plan = result.interpolated_plan[idx].squeeze(0)
|
||||
result.optimized_dt = self.tensor_args.to_device(self.interpolation_dt)
|
||||
result.optimized_plan = result.interpolated_plan[
|
||||
: graph_result.path_buffer_last_tstep[idx.item()]
|
||||
]
|
||||
idx = idx.view(-1) + self._batch_col
|
||||
result.cspace_error = torch.zeros((1), device=self.tensor_args.device)
|
||||
|
||||
result.path_buffer_last_tstep = graph_result.path_buffer_last_tstep[
|
||||
idx.item() : idx.item() + 1
|
||||
]
|
||||
result.success = torch.as_tensor([True], device=self.tensor_args.device)
|
||||
return result
|
||||
else:
|
||||
result.success = torch.as_tensor([False], device=self.tensor_args.device)
|
||||
result.status = "Graph Fail"
|
||||
if not graph_result.valid_query:
|
||||
result.valid_query = False
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["graph_debug"] = graph_result.debug_info
|
||||
return result
|
||||
if plan_config.need_graph_success:
|
||||
return result
|
||||
|
||||
# do trajopt:
|
||||
if plan_config.enable_opt:
|
||||
with profiler.record_function("motion_gen/setup_trajopt_seeds"):
|
||||
# self._trajopt_goal_config[:, :ik_success] = goal_config
|
||||
|
||||
goal = Goal(
|
||||
current_state=start_state,
|
||||
goal_state=goal_state,
|
||||
)
|
||||
|
||||
if trajopt_seed_traj is None or graph_success < solve_state.num_trajopt_seeds * 1:
|
||||
seed_goal = Goal(
|
||||
current_state=start_state.repeat_seeds(solve_state.num_trajopt_seeds),
|
||||
goal_state=goal_state.repeat_seeds(solve_state.num_trajopt_seeds),
|
||||
)
|
||||
if trajopt_seed_traj is not None:
|
||||
trajopt_seed_traj = trajopt_seed_traj.transpose(0, 1).contiguous()
|
||||
# batch, num_seeds, h, dof
|
||||
if trajopt_seed_success.shape[1] < self.js_trajopt_solver.num_seeds:
|
||||
trajopt_seed_success_new = torch.zeros(
|
||||
(1, solve_state.num_trajopt_seeds),
|
||||
device=self.tensor_args.device,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
trajopt_seed_success_new[
|
||||
0, : trajopt_seed_success.shape[1]
|
||||
] = trajopt_seed_success
|
||||
trajopt_seed_success = trajopt_seed_success_new
|
||||
# create seeds here:
|
||||
trajopt_seed_traj = self.js_trajopt_solver.get_seed_set(
|
||||
seed_goal,
|
||||
trajopt_seed_traj, # batch, num_seeds, h, dof
|
||||
num_seeds=1,
|
||||
batch_mode=False,
|
||||
seed_success=trajopt_seed_success,
|
||||
)
|
||||
trajopt_seed_traj = trajopt_seed_traj.view(
|
||||
self.js_trajopt_solver.num_seeds * 1,
|
||||
1,
|
||||
self.trajopt_solver.action_horizon,
|
||||
self._dof,
|
||||
).contiguous()
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
og_value = self.trajopt_solver.interpolation_type
|
||||
self.js_trajopt_solver.interpolation_type = InterpolateType.LINEAR_CUDA
|
||||
with profiler.record_function("motion_gen/trajopt"):
|
||||
log_info("MG: running TO")
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal,
|
||||
solve_state,
|
||||
trajopt_seed_traj,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds * 1,
|
||||
newton_iters=trajopt_newton_iters,
|
||||
return_all_solutions=plan_config.enable_finetune_trajopt,
|
||||
trajopt_instance=self.js_trajopt_solver,
|
||||
)
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
self.trajopt_solver.interpolation_type = og_value
|
||||
# self.trajopt_solver.compute_metrics(not og_evaluate, og_evaluate)
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["trajopt_result"] = traj_result
|
||||
if torch.count_nonzero(traj_result.success) == 0:
|
||||
result.status = "TrajOpt Fail"
|
||||
# run finetune
|
||||
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
|
||||
with profiler.record_function("motion_gen/finetune_trajopt"):
|
||||
seed_traj = traj_result.raw_action.clone() # solution.position.clone()
|
||||
seed_traj = seed_traj.contiguous()
|
||||
og_solve_time = traj_result.solve_time
|
||||
|
||||
scaled_dt = torch.clamp(
|
||||
torch.max(traj_result.optimized_dt[traj_result.success]),
|
||||
self.trajopt_solver.interpolation_dt,
|
||||
)
|
||||
og_dt = self.js_trajopt_solver.solver_dt
|
||||
self.js_trajopt_solver.update_solver_dt(scaled_dt.item())
|
||||
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal,
|
||||
solve_state,
|
||||
seed_traj,
|
||||
trajopt_instance=self.js_trajopt_solver,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds,
|
||||
)
|
||||
self.js_trajopt_solver.update_solver_dt(og_dt)
|
||||
|
||||
result.finetune_time = traj_result.solve_time
|
||||
|
||||
traj_result.solve_time = og_solve_time
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["finetune_trajopt_result"] = traj_result
|
||||
if torch.count_nonzero(traj_result.success) == 0:
|
||||
result.status = "Finetune Fail"
|
||||
elif plan_config.enable_finetune_trajopt:
|
||||
traj_result.success = traj_result.success[0:1]
|
||||
result.solve_time += traj_result.solve_time + result.finetune_time
|
||||
result.trajopt_time = traj_result.solve_time
|
||||
result.trajopt_attempts = 1
|
||||
result.success = traj_result.success
|
||||
|
||||
result.interpolated_plan = traj_result.interpolated_solution.trim_trajectory(
|
||||
0, traj_result.path_buffer_last_tstep[0]
|
||||
)
|
||||
# print(ik_result.position_error[ik_result.success] * 1000.0)
|
||||
# print(traj_result.position_error * 1000.0)
|
||||
# exit()
|
||||
result.interpolation_dt = self.trajopt_solver.interpolation_dt
|
||||
result.path_buffer_last_tstep = traj_result.path_buffer_last_tstep
|
||||
result.cspace_error = traj_result.cspace_error
|
||||
result.optimized_dt = traj_result.optimized_dt
|
||||
result.optimized_plan = traj_result.solution
|
||||
|
||||
return result
|
||||
|
||||
@@ -1644,7 +1853,7 @@ class MotionGen(MotionGenConfig):
|
||||
if plan_config.enable_graph:
|
||||
interpolation_steps = None
|
||||
if plan_config.enable_opt:
|
||||
interpolation_steps = self.trajopt_solver.traj_tsteps - 4
|
||||
interpolation_steps = self.trajopt_solver.action_horizon
|
||||
|
||||
start_graph_state = start_state.repeat_seeds(ik_out_seeds)
|
||||
start_config = start_graph_state.position[ik_result.success.view(-1)].view(
|
||||
@@ -1662,23 +1871,17 @@ class MotionGen(MotionGenConfig):
|
||||
result.used_graph = True
|
||||
|
||||
if plan_config.enable_opt:
|
||||
trajopt_seed = (
|
||||
result.graph_plan.position.view(
|
||||
1, # solve_state.batch_size,
|
||||
graph_success, # solve_state.num_trajopt_seeds,
|
||||
interpolation_steps,
|
||||
self._dof,
|
||||
)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
)
|
||||
trajopt_seed = result.graph_plan.position.view(
|
||||
graph_success, # solve_state.num_trajopt_seeds,
|
||||
interpolation_steps,
|
||||
self._dof,
|
||||
).contiguous()
|
||||
trajopt_seed_traj = torch.zeros(
|
||||
(trajopt_seed.shape[0], 1, self.trajopt_solver.traj_tsteps, self._dof),
|
||||
(1, trajopt_seed.shape[0], self.trajopt_solver.action_horizon, self._dof),
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
)
|
||||
trajopt_seed_traj[:, :, :-4, :] = trajopt_seed
|
||||
trajopt_seed_traj[:, :, -4:, :] = trajopt_seed_traj[:, :, -5:-4, :]
|
||||
trajopt_seed_traj[0, :, :interpolation_steps, :] = trajopt_seed
|
||||
|
||||
trajopt_seed_success = ik_result.success.clone()
|
||||
trajopt_seed_success[ik_result.success] = graph_result.success
|
||||
@@ -1766,14 +1969,54 @@ class MotionGen(MotionGenConfig):
|
||||
trajopt_seed_traj = trajopt_seed_traj.view(
|
||||
solve_state.num_trajopt_seeds,
|
||||
solve_state.batch_size,
|
||||
self.trajopt_solver.traj_tsteps,
|
||||
self.trajopt_solver.action_horizon,
|
||||
self._dof,
|
||||
).contiguous()
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
og_value = self.trajopt_solver.interpolation_type
|
||||
self.trajopt_solver.interpolation_type = InterpolateType.LINEAR_CUDA
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal, solve_state, trajopt_seed_traj, newton_iters=trajopt_newton_iters
|
||||
goal,
|
||||
solve_state,
|
||||
trajopt_seed_traj,
|
||||
newton_iters=trajopt_newton_iters,
|
||||
return_all_solutions=True,
|
||||
)
|
||||
|
||||
# output of traj result will have 1 solution per batch
|
||||
|
||||
# run finetune opt on 1 solution per batch:
|
||||
if plan_config.enable_finetune_trajopt:
|
||||
self.trajopt_solver.interpolation_type = og_value
|
||||
# self.trajopt_solver.compute_metrics(not og_evaluate, og_evaluate)
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["trajopt_result"] = traj_result
|
||||
# run finetune
|
||||
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
|
||||
with profiler.record_function("motion_gen/finetune_trajopt"):
|
||||
seed_traj = traj_result.raw_action.clone() # solution.position.clone()
|
||||
seed_traj = seed_traj.contiguous()
|
||||
og_solve_time = traj_result.solve_time
|
||||
|
||||
scaled_dt = torch.clamp(
|
||||
torch.max(traj_result.optimized_dt[traj_result.success])
|
||||
* self.finetune_dt_scale,
|
||||
self.trajopt_solver.interpolation_dt,
|
||||
)
|
||||
self.finetune_trajopt_solver.update_solver_dt(scaled_dt.item())
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal,
|
||||
solve_state,
|
||||
seed_traj,
|
||||
trajopt_instance=self.finetune_trajopt_solver,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds,
|
||||
)
|
||||
|
||||
result.finetune_time = traj_result.solve_time
|
||||
|
||||
traj_result.solve_time = og_solve_time
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["finetune_trajopt_result"] = traj_result
|
||||
result.success = traj_result.success
|
||||
|
||||
result.interpolated_plan = traj_result.interpolated_solution
|
||||
@@ -1837,6 +2080,7 @@ class MotionGen(MotionGenConfig):
|
||||
batch: Optional[int] = None,
|
||||
warmup_js_trajopt: bool = True,
|
||||
batch_env_mode: bool = False,
|
||||
parallel_finetune: bool = False,
|
||||
):
|
||||
log_info("Warmup")
|
||||
|
||||
@@ -1854,7 +2098,11 @@ class MotionGen(MotionGenConfig):
|
||||
self.plan_single(
|
||||
start_state,
|
||||
retract_pose,
|
||||
MotionGenPlanConfig(max_attempts=1, enable_finetune_trajopt=True),
|
||||
MotionGenPlanConfig(
|
||||
max_attempts=1,
|
||||
enable_finetune_trajopt=True,
|
||||
parallel_finetune=parallel_finetune,
|
||||
),
|
||||
link_poses=link_poses,
|
||||
)
|
||||
if enable_graph:
|
||||
@@ -1867,7 +2115,10 @@ class MotionGen(MotionGenConfig):
|
||||
start_state,
|
||||
retract_pose,
|
||||
MotionGenPlanConfig(
|
||||
max_attempts=1, enable_finetune_trajopt=True, enable_graph=enable_graph
|
||||
max_attempts=1,
|
||||
enable_finetune_trajopt=True,
|
||||
enable_graph=enable_graph,
|
||||
parallel_finetune=parallel_finetune,
|
||||
),
|
||||
link_poses=link_poses,
|
||||
)
|
||||
@@ -1925,14 +2176,24 @@ class MotionGen(MotionGenConfig):
|
||||
}
|
||||
result = None
|
||||
goal = Goal(goal_state=goal_state, current_state=start_state)
|
||||
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.SINGLE,
|
||||
num_ik_seeds=1,
|
||||
num_trajopt_seeds=self.js_trajopt_solver.num_seeds,
|
||||
num_graph_seeds=self.js_trajopt_solver.num_seeds,
|
||||
batch_size=1,
|
||||
n_envs=1,
|
||||
n_goalset=1,
|
||||
)
|
||||
for n in range(plan_config.max_attempts):
|
||||
traj_result = self.js_trajopt_solver.solve_single(goal)
|
||||
traj_result = self._plan_js_from_solve_state(
|
||||
solve_state, start_state, goal_state, plan_config=plan_config
|
||||
)
|
||||
time_dict["trajopt_time"] += traj_result.solve_time
|
||||
time_dict["trajopt_attempts"] = n
|
||||
|
||||
if result is None:
|
||||
result = MotionGenResult(success=traj_result.success)
|
||||
result = traj_result
|
||||
|
||||
if traj_result.success.item():
|
||||
break
|
||||
@@ -1940,25 +2201,7 @@ class MotionGen(MotionGenConfig):
|
||||
result.solve_time = time_dict["trajopt_time"]
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info = {"trajopt_result": traj_result}
|
||||
status = None
|
||||
if not traj_result.success.item():
|
||||
# print(traj_result.cspace_error, traj_result.success)
|
||||
status = ""
|
||||
if traj_result.cspace_error.item() >= self.js_trajopt_solver.cspace_threshold:
|
||||
status += " Fail: C-SPACE Convergence"
|
||||
if torch.count_nonzero(~traj_result.metrics.feasible).item() > 0:
|
||||
status += " Fail: Constraints"
|
||||
# print(traj_result.metrics.feasible)
|
||||
|
||||
result.status = status
|
||||
result.position_error = traj_result.position_error
|
||||
result.rotation_error = traj_result.rotation_error
|
||||
result.cspace_error = traj_result.cspace_error
|
||||
result.optimized_dt = traj_result.optimized_dt
|
||||
result.interpolated_plan = traj_result.interpolated_solution
|
||||
result.optimized_plan = traj_result.solution
|
||||
result.path_buffer_last_tstep = traj_result.path_buffer_last_tstep
|
||||
result.success = traj_result.success
|
||||
return result
|
||||
|
||||
def plan(
|
||||
|
||||
Reference in New Issue
Block a user