improved joint space planning

This commit is contained in:
Balakumar Sundaralingam
2024-05-30 14:42:22 -07:00
parent 3bfed9d773
commit 0c51dd2da8
28 changed files with 1135 additions and 213 deletions

View File

@@ -121,6 +121,9 @@ class MotionGenConfig:
#: instance of trajectory optimization solver to use for reaching joint space targets.
js_trajopt_solver: TrajOptSolver
#: instance of trajectory optimization solver for final fine tuning for joint space targets.
finetune_js_trajopt_solver: TrajOptSolver
#: instance of trajectory optimization solver for final fine tuning.
finetune_trajopt_solver: TrajOptSolver
@@ -760,6 +763,7 @@ class MotionGenConfig:
minimize_jerk=minimize_jerk,
filter_robot_command=filter_robot_command,
optimize_dt=optimize_dt,
num_seeds=num_trajopt_noisy_seeds,
)
js_trajopt_solver = TrajOptSolver(js_trajopt_cfg)
@@ -805,6 +809,48 @@ class MotionGenConfig:
)
finetune_trajopt_solver = TrajOptSolver(finetune_trajopt_cfg)
finetune_js_trajopt_cfg = TrajOptSolverConfig.load_from_robot_config(
robot_cfg=robot_cfg,
world_model=world_model,
tensor_args=tensor_args,
position_threshold=position_threshold,
rotation_threshold=rotation_threshold,
world_coll_checker=world_coll_checker,
base_cfg_file=base_config_data,
particle_file=particle_trajopt_file,
gradient_file=finetune_trajopt_file,
traj_tsteps=js_trajopt_tsteps,
interpolation_type=interpolation_type,
interpolation_steps=interpolation_steps,
use_cuda_graph=use_cuda_graph,
self_collision_check=self_collision_check,
self_collision_opt=self_collision_opt,
grad_trajopt_iters=grad_trajopt_iters,
interpolation_dt=interpolation_dt,
use_particle_opt=False,
traj_evaluator_config=traj_evaluator_config,
traj_evaluator=traj_evaluator,
use_gradient_descent=use_gradient_descent,
use_es=use_es_trajopt,
es_learning_rate=es_trajopt_learning_rate,
use_fixed_samples=use_trajopt_fixed_samples,
evaluate_interpolated_trajectory=evaluate_interpolated_trajectory,
fixed_iters=fixed_iters_trajopt,
store_debug=store_trajopt_debug,
collision_activation_distance=collision_activation_distance,
trajopt_dt=js_trajopt_dt,
store_debug_in_result=store_debug_in_result,
smooth_weight=smooth_weight,
cspace_threshold=cspace_threshold,
state_finite_difference_mode=state_finite_difference_mode,
minimize_jerk=minimize_jerk,
filter_robot_command=filter_robot_command,
optimize_dt=optimize_dt,
num_seeds=num_trajopt_noisy_seeds,
)
finetune_js_trajopt_solver = TrajOptSolver(finetune_js_trajopt_cfg)
if graph_trajopt_iters is not None:
graph_trajopt_iters = math.ceil(
graph_trajopt_iters / finetune_trajopt_solver.solver.newton_optimizer.inner_iters
@@ -823,6 +869,7 @@ class MotionGenConfig:
graph_planner,
trajopt_solver=trajopt_solver,
js_trajopt_solver=js_trajopt_solver,
finetune_js_trajopt_solver=finetune_js_trajopt_solver,
finetune_trajopt_solver=finetune_trajopt_solver,
interpolation_type=interpolation_type,
interpolation_steps=interpolation_steps,
@@ -951,6 +998,9 @@ class MotionGenPlanConfig:
#: check for joint limits, self-collision, and collision with the world.
check_start_validity: bool = True
#: Finetune dt scale for joint space planning.
finetune_js_dt_scale: Optional[float] = 1.1
def __post_init__(self):
"""Post initialization checks."""
if not self.enable_opt and not self.enable_graph:
@@ -983,6 +1033,7 @@ class MotionGenPlanConfig:
finetune_dt_scale=self.finetune_dt_scale,
finetune_attempts=self.finetune_attempts,
time_dilation_factor=self.time_dilation_factor,
finetune_js_dt_scale=self.finetune_js_dt_scale,
)
@@ -1760,6 +1811,7 @@ class MotionGen(MotionGenConfig):
self.ik_solver.reset_seed()
self.graph_planner.reset_seed()
self.trajopt_solver.reset_seed()
self.js_trajopt_solver.reset_seed()
def get_retract_config(self) -> T_DOF:
"""Returns the retract/home configuration of the robot."""
@@ -1803,7 +1855,12 @@ class MotionGen(MotionGenConfig):
goal_state = start_state.clone()
goal_state.position[..., warmup_joint_index] += warmup_joint_delta
for _ in range(3):
self.plan_single_js(start_state, goal_state, MotionGenPlanConfig(max_attempts=1))
self.plan_single_js(
start_state.clone(),
goal_state.clone(),
MotionGenPlanConfig(max_attempts=1, enable_finetune_trajopt=True),
)
if enable_graph:
start_state = JointState.from_position(
self.rollout_fn.dynamics_model.retract_config.view(1, -1).clone(),
@@ -1983,7 +2040,7 @@ class MotionGen(MotionGenConfig):
"finetune_time": 0,
}
result = None
goal = Goal(goal_state=goal_state, current_state=start_state)
# goal = Goal(goal_state=goal_state, current_state=start_state)
solve_state = ReacherSolveState(
ReacherSolveType.SINGLE,
num_ik_seeds=1,
@@ -2009,7 +2066,7 @@ class MotionGen(MotionGenConfig):
result = self._plan_js_from_solve_state(
solve_state, start_state, goal_state, plan_config=plan_config
)
time_dict["trajopt_time"] += result.solve_time
time_dict["trajopt_time"] += result.trajopt_time
time_dict["graph_time"] += result.graph_time
time_dict["finetune_time"] += result.finetune_time
time_dict["trajopt_attempts"] = n
@@ -2160,6 +2217,7 @@ class MotionGen(MotionGenConfig):
+ self.trajopt_solver.get_all_rollout_instances()
+ self.finetune_trajopt_solver.get_all_rollout_instances()
+ self.js_trajopt_solver.get_all_rollout_instances()
+ self.finetune_js_trajopt_solver.get_all_rollout_instances()
)
return self._rollout_list
@@ -2171,6 +2229,7 @@ class MotionGen(MotionGenConfig):
+ self.trajopt_solver.solver.get_all_rollout_instances()
+ self.finetune_trajopt_solver.solver.get_all_rollout_instances()
+ self.js_trajopt_solver.solver.get_all_rollout_instances()
+ self.finetune_js_trajopt_solver.solver.get_all_rollout_instances()
)
return self._solver_rollout_list
@@ -2529,6 +2588,11 @@ class MotionGen(MotionGenConfig):
"""Check if the pose cost metric is projected to goal frame."""
return self.trajopt_solver.rollout_fn.goal_cost.project_distance
@property
def joint_names(self) -> List[str]:
"""Get the joint names of the robot."""
return self.rollout_fn.joint_names
def update_interpolation_type(
self,
interpolation_type: InterpolateType,
@@ -2548,6 +2612,7 @@ class MotionGen(MotionGenConfig):
self.trajopt_solver.interpolation_type = interpolation_type
self.finetune_trajopt_solver.interpolation_type = interpolation_type
self.js_trajopt_solver.interpolation_type = interpolation_type
self.finetune_js_trajopt_solver.interpolation_type = interpolation_type
def update_locked_joints(
self, lock_joints: Dict[str, float], robot_config_dict: Union[str, Dict[Any]]
@@ -3375,8 +3440,9 @@ class MotionGen(MotionGenConfig):
seed_override = solve_state.num_trajopt_seeds * self.noisy_trajopt_seeds
finetune_time = 0
newton_iters = None
for k in range(plan_config.finetune_attempts):
newton_iters = None
scaled_dt = torch.clamp(
opt_dt
@@ -3396,7 +3462,7 @@ class MotionGen(MotionGenConfig):
newton_iters=newton_iters,
)
finetune_time += traj_result.solve_time
if torch.count_nonzero(traj_result.success) > 0:
if torch.count_nonzero(traj_result.success) > 0 or not self.optimize_dt:
break
seed_traj = traj_result.optimized_seeds.detach().clone()
newton_iters = 4
@@ -3592,7 +3658,7 @@ class MotionGen(MotionGenConfig):
solve_state,
trajopt_seed_traj,
num_seeds_override=solve_state.num_trajopt_seeds,
newton_iters=trajopt_newton_iters + 2,
newton_iters=trajopt_newton_iters,
return_all_solutions=plan_config.enable_finetune_trajopt,
trajopt_instance=self.js_trajopt_solver,
)
@@ -3605,29 +3671,42 @@ class MotionGen(MotionGenConfig):
# run finetune
if plan_config.enable_finetune_trajopt and torch.count_nonzero(traj_result.success) > 0:
with profiler.record_function("motion_gen/finetune_trajopt"):
seed_traj = traj_result.raw_action.clone() # solution.position.clone()
seed_traj = seed_traj.contiguous()
seed_traj = traj_result.raw_action.clone()
og_solve_time = traj_result.solve_time
opt_dt = traj_result.optimized_dt
opt_dt = torch.min(opt_dt[traj_result.success])
finetune_time = 0
newton_iters = None
for k in range(plan_config.finetune_attempts):
scaled_dt = torch.clamp(
opt_dt
* plan_config.finetune_js_dt_scale
* (plan_config.finetune_dt_decay ** (k)),
self.js_trajopt_solver.minimum_trajectory_dt,
)
scaled_dt = torch.clamp(
torch.max(traj_result.optimized_dt[traj_result.success]),
self.trajopt_solver.minimum_trajectory_dt,
)
og_dt = self.js_trajopt_solver.solver_dt.clone()
self.js_trajopt_solver.update_solver_dt(scaled_dt.item())
traj_result = self._solve_trajopt_from_solve_state(
goal,
solve_state,
seed_traj,
trajopt_instance=self.js_trajopt_solver,
num_seeds_override=solve_state.num_trajopt_seeds,
newton_iters=trajopt_newton_iters + 4,
)
self.js_trajopt_solver.update_solver_dt(og_dt)
if self.optimize_dt:
self.finetune_js_trajopt_solver.update_solver_dt(scaled_dt.item())
traj_result = self._solve_trajopt_from_solve_state(
goal,
solve_state,
seed_traj,
trajopt_instance=self.finetune_js_trajopt_solver,
num_seeds_override=solve_state.num_trajopt_seeds,
newton_iters=newton_iters,
return_all_solutions=False,
)
result.finetune_time = traj_result.solve_time
finetune_time += traj_result.solve_time
if torch.count_nonzero(traj_result.success) > 0 or not self.optimize_dt:
break
traj_result.solve_time = og_solve_time
seed_traj = traj_result.optimized_seeds.detach().clone()
newton_iters = 4
result.finetune_time = finetune_time
traj_result.solve_time = og_solve_time
if self.store_debug_in_result:
result.debug_info["finetune_trajopt_result"] = traj_result
if torch.count_nonzero(traj_result.success) == 0:
@@ -3638,7 +3717,6 @@ class MotionGen(MotionGenConfig):
result.trajopt_time = traj_result.solve_time
result.trajopt_attempts = 1
result.success = traj_result.success
result.interpolated_plan = traj_result.interpolated_solution.trim_trajectory(
0, traj_result.path_buffer_last_tstep[0]
)