improved joint space planning
This commit is contained in:
@@ -121,6 +121,9 @@ class MotionGenConfig:
|
||||
#: instance of trajectory optimization solver to use for reaching joint space targets.
|
||||
js_trajopt_solver: TrajOptSolver
|
||||
|
||||
#: instance of trajectory optimization solver for final fine tuning for joint space targets.
|
||||
finetune_js_trajopt_solver: TrajOptSolver
|
||||
|
||||
#: instance of trajectory optimization solver for final fine tuning.
|
||||
finetune_trajopt_solver: TrajOptSolver
|
||||
|
||||
@@ -760,6 +763,7 @@ class MotionGenConfig:
|
||||
minimize_jerk=minimize_jerk,
|
||||
filter_robot_command=filter_robot_command,
|
||||
optimize_dt=optimize_dt,
|
||||
num_seeds=num_trajopt_noisy_seeds,
|
||||
)
|
||||
js_trajopt_solver = TrajOptSolver(js_trajopt_cfg)
|
||||
|
||||
@@ -805,6 +809,48 @@ class MotionGenConfig:
|
||||
)
|
||||
|
||||
finetune_trajopt_solver = TrajOptSolver(finetune_trajopt_cfg)
|
||||
|
||||
finetune_js_trajopt_cfg = TrajOptSolverConfig.load_from_robot_config(
|
||||
robot_cfg=robot_cfg,
|
||||
world_model=world_model,
|
||||
tensor_args=tensor_args,
|
||||
position_threshold=position_threshold,
|
||||
rotation_threshold=rotation_threshold,
|
||||
world_coll_checker=world_coll_checker,
|
||||
base_cfg_file=base_config_data,
|
||||
particle_file=particle_trajopt_file,
|
||||
gradient_file=finetune_trajopt_file,
|
||||
traj_tsteps=js_trajopt_tsteps,
|
||||
interpolation_type=interpolation_type,
|
||||
interpolation_steps=interpolation_steps,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
self_collision_check=self_collision_check,
|
||||
self_collision_opt=self_collision_opt,
|
||||
grad_trajopt_iters=grad_trajopt_iters,
|
||||
interpolation_dt=interpolation_dt,
|
||||
use_particle_opt=False,
|
||||
traj_evaluator_config=traj_evaluator_config,
|
||||
traj_evaluator=traj_evaluator,
|
||||
use_gradient_descent=use_gradient_descent,
|
||||
use_es=use_es_trajopt,
|
||||
es_learning_rate=es_trajopt_learning_rate,
|
||||
use_fixed_samples=use_trajopt_fixed_samples,
|
||||
evaluate_interpolated_trajectory=evaluate_interpolated_trajectory,
|
||||
fixed_iters=fixed_iters_trajopt,
|
||||
store_debug=store_trajopt_debug,
|
||||
collision_activation_distance=collision_activation_distance,
|
||||
trajopt_dt=js_trajopt_dt,
|
||||
store_debug_in_result=store_debug_in_result,
|
||||
smooth_weight=smooth_weight,
|
||||
cspace_threshold=cspace_threshold,
|
||||
state_finite_difference_mode=state_finite_difference_mode,
|
||||
minimize_jerk=minimize_jerk,
|
||||
filter_robot_command=filter_robot_command,
|
||||
optimize_dt=optimize_dt,
|
||||
num_seeds=num_trajopt_noisy_seeds,
|
||||
)
|
||||
finetune_js_trajopt_solver = TrajOptSolver(finetune_js_trajopt_cfg)
|
||||
|
||||
if graph_trajopt_iters is not None:
|
||||
graph_trajopt_iters = math.ceil(
|
||||
graph_trajopt_iters / finetune_trajopt_solver.solver.newton_optimizer.inner_iters
|
||||
@@ -823,6 +869,7 @@ class MotionGenConfig:
|
||||
graph_planner,
|
||||
trajopt_solver=trajopt_solver,
|
||||
js_trajopt_solver=js_trajopt_solver,
|
||||
finetune_js_trajopt_solver=finetune_js_trajopt_solver,
|
||||
finetune_trajopt_solver=finetune_trajopt_solver,
|
||||
interpolation_type=interpolation_type,
|
||||
interpolation_steps=interpolation_steps,
|
||||
@@ -951,6 +998,9 @@ class MotionGenPlanConfig:
|
||||
#: check for joint limits, self-collision, and collision with the world.
|
||||
check_start_validity: bool = True
|
||||
|
||||
#: Finetune dt scale for joint space planning.
|
||||
finetune_js_dt_scale: Optional[float] = 1.1
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post initialization checks."""
|
||||
if not self.enable_opt and not self.enable_graph:
|
||||
@@ -983,6 +1033,7 @@ class MotionGenPlanConfig:
|
||||
finetune_dt_scale=self.finetune_dt_scale,
|
||||
finetune_attempts=self.finetune_attempts,
|
||||
time_dilation_factor=self.time_dilation_factor,
|
||||
finetune_js_dt_scale=self.finetune_js_dt_scale,
|
||||
)
|
||||
|
||||
|
||||
@@ -1760,6 +1811,7 @@ class MotionGen(MotionGenConfig):
|
||||
self.ik_solver.reset_seed()
|
||||
self.graph_planner.reset_seed()
|
||||
self.trajopt_solver.reset_seed()
|
||||
self.js_trajopt_solver.reset_seed()
|
||||
|
||||
def get_retract_config(self) -> T_DOF:
|
||||
"""Returns the retract/home configuration of the robot."""
|
||||
@@ -1803,7 +1855,12 @@ class MotionGen(MotionGenConfig):
|
||||
goal_state = start_state.clone()
|
||||
goal_state.position[..., warmup_joint_index] += warmup_joint_delta
|
||||
for _ in range(3):
|
||||
self.plan_single_js(start_state, goal_state, MotionGenPlanConfig(max_attempts=1))
|
||||
self.plan_single_js(
|
||||
start_state.clone(),
|
||||
goal_state.clone(),
|
||||
MotionGenPlanConfig(max_attempts=1, enable_finetune_trajopt=True),
|
||||
)
|
||||
|
||||
if enable_graph:
|
||||
start_state = JointState.from_position(
|
||||
self.rollout_fn.dynamics_model.retract_config.view(1, -1).clone(),
|
||||
@@ -1983,7 +2040,7 @@ class MotionGen(MotionGenConfig):
|
||||
"finetune_time": 0,
|
||||
}
|
||||
result = None
|
||||
goal = Goal(goal_state=goal_state, current_state=start_state)
|
||||
# goal = Goal(goal_state=goal_state, current_state=start_state)
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.SINGLE,
|
||||
num_ik_seeds=1,
|
||||
@@ -2009,7 +2066,7 @@ class MotionGen(MotionGenConfig):
|
||||
result = self._plan_js_from_solve_state(
|
||||
solve_state, start_state, goal_state, plan_config=plan_config
|
||||
)
|
||||
time_dict["trajopt_time"] += result.solve_time
|
||||
time_dict["trajopt_time"] += result.trajopt_time
|
||||
time_dict["graph_time"] += result.graph_time
|
||||
time_dict["finetune_time"] += result.finetune_time
|
||||
time_dict["trajopt_attempts"] = n
|
||||
@@ -2160,6 +2217,7 @@ class MotionGen(MotionGenConfig):
|
||||
+ self.trajopt_solver.get_all_rollout_instances()
|
||||
+ self.finetune_trajopt_solver.get_all_rollout_instances()
|
||||
+ self.js_trajopt_solver.get_all_rollout_instances()
|
||||
+ self.finetune_js_trajopt_solver.get_all_rollout_instances()
|
||||
)
|
||||
return self._rollout_list
|
||||
|
||||
@@ -2171,6 +2229,7 @@ class MotionGen(MotionGenConfig):
|
||||
+ self.trajopt_solver.solver.get_all_rollout_instances()
|
||||
+ self.finetune_trajopt_solver.solver.get_all_rollout_instances()
|
||||
+ self.js_trajopt_solver.solver.get_all_rollout_instances()
|
||||
+ self.finetune_js_trajopt_solver.solver.get_all_rollout_instances()
|
||||
)
|
||||
return self._solver_rollout_list
|
||||
|
||||
@@ -2529,6 +2588,11 @@ class MotionGen(MotionGenConfig):
|
||||
"""Check if the pose cost metric is projected to goal frame."""
|
||||
return self.trajopt_solver.rollout_fn.goal_cost.project_distance
|
||||
|
||||
@property
|
||||
def joint_names(self) -> List[str]:
|
||||
"""Get the joint names of the robot."""
|
||||
return self.rollout_fn.joint_names
|
||||
|
||||
def update_interpolation_type(
|
||||
self,
|
||||
interpolation_type: InterpolateType,
|
||||
@@ -2548,6 +2612,7 @@ class MotionGen(MotionGenConfig):
|
||||
self.trajopt_solver.interpolation_type = interpolation_type
|
||||
self.finetune_trajopt_solver.interpolation_type = interpolation_type
|
||||
self.js_trajopt_solver.interpolation_type = interpolation_type
|
||||
self.finetune_js_trajopt_solver.interpolation_type = interpolation_type
|
||||
|
||||
def update_locked_joints(
|
||||
self, lock_joints: Dict[str, float], robot_config_dict: Union[str, Dict[Any]]
|
||||
@@ -3375,8 +3440,9 @@ class MotionGen(MotionGenConfig):
|
||||
seed_override = solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds
|
||||
|
||||
finetune_time = 0
|
||||
newton_iters = None
|
||||
|
||||
for k in range(plan_config.finetune_attempts):
|
||||
newton_iters = None
|
||||
|
||||
scaled_dt = torch.clamp(
|
||||
opt_dt
|
||||
@@ -3396,7 +3462,7 @@ class MotionGen(MotionGenConfig):
|
||||
newton_iters=newton_iters,
|
||||
)
|
||||
finetune_time += traj_result.solve_time
|
||||
if torch.count_nonzero(traj_result.success) > 0:
|
||||
if torch.count_nonzero(traj_result.success) > 0 or not self.optimize_dt:
|
||||
break
|
||||
seed_traj = traj_result.optimized_seeds.detach().clone()
|
||||
newton_iters = 4
|
||||
@@ -3592,7 +3658,7 @@ class MotionGen(MotionGenConfig):
|
||||
solve_state,
|
||||
trajopt_seed_traj,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds,
|
||||
newton_iters=trajopt_newton_iters + 2,
|
||||
newton_iters=trajopt_newton_iters,
|
||||
return_all_solutions=plan_config.enable_finetune_trajopt,
|
||||
trajopt_instance=self.js_trajopt_solver,
|
||||
)
|
||||
@@ -3605,29 +3671,42 @@ class MotionGen(MotionGenConfig):
|
||||
# run finetune
|
||||
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
|
||||
with profiler.record_function("motion_gen/finetune_trajopt"):
|
||||
seed_traj = traj_result.raw_action.clone() # solution.position.clone()
|
||||
seed_traj = seed_traj.contiguous()
|
||||
seed_traj = traj_result.raw_action.clone()
|
||||
og_solve_time = traj_result.solve_time
|
||||
opt_dt = traj_result.optimized_dt
|
||||
opt_dt = torch.min(opt_dt[traj_result.success])
|
||||
finetune_time = 0
|
||||
newton_iters = None
|
||||
for k in range(plan_config.finetune_attempts):
|
||||
scaled_dt = torch.clamp(
|
||||
opt_dt
|
||||
* plan_config.finetune_js_dt_scale
|
||||
* (plan_config.finetune_dt_decay ** (k)),
|
||||
self.js_trajopt_solver.minimum_trajectory_dt,
|
||||
)
|
||||
|
||||
scaled_dt = torch.clamp(
|
||||
torch.max(traj_result.optimized_dt[traj_result.success]),
|
||||
self.trajopt_solver.minimum_trajectory_dt,
|
||||
)
|
||||
og_dt = self.js_trajopt_solver.solver_dt.clone()
|
||||
self.js_trajopt_solver.update_solver_dt(scaled_dt.item())
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal,
|
||||
solve_state,
|
||||
seed_traj,
|
||||
trajopt_instance=self.js_trajopt_solver,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds,
|
||||
newton_iters=trajopt_newton_iters + 4,
|
||||
)
|
||||
self.js_trajopt_solver.update_solver_dt(og_dt)
|
||||
if self.optimize_dt:
|
||||
self.finetune_js_trajopt_solver.update_solver_dt(scaled_dt.item())
|
||||
traj_result = self._solve_trajopt_from_solve_state(
|
||||
goal,
|
||||
solve_state,
|
||||
seed_traj,
|
||||
trajopt_instance=self.finetune_js_trajopt_solver,
|
||||
num_seeds_override=solve_state.num_trajopt_seeds,
|
||||
newton_iters=newton_iters,
|
||||
return_all_solutions=False,
|
||||
)
|
||||
|
||||
result.finetune_time = traj_result.solve_time
|
||||
finetune_time += traj_result.solve_time
|
||||
if torch.count_nonzero(traj_result.success) > 0 or not self.optimize_dt:
|
||||
break
|
||||
|
||||
traj_result.solve_time = og_solve_time
|
||||
seed_traj = traj_result.optimized_seeds.detach().clone()
|
||||
newton_iters = 4
|
||||
|
||||
result.finetune_time = finetune_time
|
||||
|
||||
traj_result.solve_time = og_solve_time
|
||||
if self.store_debug_in_result:
|
||||
result.debug_info["finetune_trajopt_result"] = traj_result
|
||||
if torch.count_nonzero(traj_result.success) == 0:
|
||||
@@ -3638,7 +3717,6 @@ class MotionGen(MotionGenConfig):
|
||||
result.trajopt_time = traj_result.solve_time
|
||||
result.trajopt_attempts = 1
|
||||
result.success = traj_result.success
|
||||
|
||||
result.interpolated_plan = traj_result.interpolated_solution.trim_trajectory(
|
||||
0, traj_result.path_buffer_last_tstep[0]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user