# # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. # # Standard Library import math import time from dataclasses import dataclass from typing import Any, Dict, List, Optional, Sequence, Union # Third Party import torch import torch.autograd.profiler as profiler # CuRobo from curobo.cuda_robot_model.cuda_robot_model import CudaRobotModel, CudaRobotModelState from curobo.geom.sdf.utils import create_collision_checker from curobo.geom.sdf.world import CollisionCheckerType, WorldCollision, WorldCollisionConfig from curobo.geom.types import WorldConfig from curobo.opt.newton.lbfgs import LBFGSOpt, LBFGSOptConfig from curobo.opt.newton.newton_base import NewtonOptBase, NewtonOptConfig from curobo.opt.particle.parallel_es import ParallelES, ParallelESConfig from curobo.opt.particle.parallel_mppi import ParallelMPPI, ParallelMPPIConfig from curobo.rollout.arm_reacher import ArmReacher, ArmReacherConfig from curobo.rollout.cost.pose_cost import PoseCostMetric from curobo.rollout.dynamics_model.integration_utils import ( action_interpolate_kernel, interpolate_kernel, ) from curobo.rollout.rollout_base import Goal, RolloutBase, RolloutMetrics from curobo.types.base import TensorDeviceType from curobo.types.robot import JointState, RobotConfig from curobo.types.tensor import T_BDOF, T_DOF, T_BValue_bool, T_BValue_float from curobo.util.helpers import list_idx_if_not_none from curobo.util.logger import log_error, log_info, log_warn from curobo.util.torch_utils import get_torch_jit_decorator, is_torch_compile_available from curobo.util.trajectory import ( 
InterpolateType, calculate_dt_no_clamp, get_batch_interpolated_trajectory, ) from curobo.util_file import get_robot_configs_path, get_task_configs_path, join_path, load_yaml from curobo.wrap.reacher.evaluator import TrajEvaluator, TrajEvaluatorConfig from curobo.wrap.reacher.types import ReacherSolveState, ReacherSolveType from curobo.wrap.wrap_base import WrapBase, WrapConfig, WrapResult @dataclass class TrajOptSolverConfig: robot_config: RobotConfig solver: WrapBase rollout_fn: ArmReacher position_threshold: float rotation_threshold: float traj_tsteps: int use_cspace_seed: bool = True interpolation_type: InterpolateType = InterpolateType.LINEAR_CUDA interpolation_steps: int = 1000 world_coll_checker: Optional[WorldCollision] = None seed_ratio: Optional[Dict[str, int]] = None num_seeds: int = 1 bias_node: Optional[T_DOF] = None interpolation_dt: float = 0.01 traj_evaluator_config: TrajEvaluatorConfig = TrajEvaluatorConfig() traj_evaluator: Optional[TrajEvaluator] = None evaluate_interpolated_trajectory: bool = True cspace_threshold: float = 0.1 tensor_args: TensorDeviceType = TensorDeviceType() sync_cuda_time: bool = True interpolate_rollout: Optional[ArmReacher] = None use_cuda_graph_metrics: bool = False trim_steps: Optional[List[int]] = None store_debug_in_result: bool = False optimize_dt: bool = True use_cuda_graph: bool = True @staticmethod @profiler.record_function("trajopt_config/load_from_robot_config") def load_from_robot_config( robot_cfg: Union[str, Dict, RobotConfig], world_model: Optional[ Union[Union[List[Dict], List[WorldConfig]], Union[Dict, WorldConfig]] ] = None, tensor_args: TensorDeviceType = TensorDeviceType(), position_threshold: float = 0.005, rotation_threshold: float = 0.05, cspace_threshold: float = 0.05, world_coll_checker=None, base_cfg_file: str = "base_cfg.yml", particle_file: str = "particle_trajopt.yml", gradient_file: str = "gradient_trajopt.yml", traj_tsteps: Optional[int] = None, interpolation_type: InterpolateType = 
InterpolateType.LINEAR_CUDA, interpolation_steps: int = 10000, interpolation_dt: float = 0.01, use_cuda_graph: bool = True, self_collision_check: bool = False, self_collision_opt: bool = True, grad_trajopt_iters: Optional[int] = None, num_seeds: int = 2, seed_ratio: Dict[str, int] = {"linear": 1.0, "bias": 0.0, "start": 0.0, "end": 0.0}, use_particle_opt: bool = True, collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.MESH, traj_evaluator_config: TrajEvaluatorConfig = TrajEvaluatorConfig(), traj_evaluator: Optional[TrajEvaluator] = None, minimize_jerk: bool = True, use_gradient_descent: bool = False, collision_cache: Optional[Dict[str, int]] = None, n_collision_envs: Optional[int] = None, use_es: Optional[bool] = None, es_learning_rate: Optional[float] = 0.1, use_fixed_samples: Optional[bool] = None, aux_rollout: Optional[ArmReacher] = None, evaluate_interpolated_trajectory: bool = True, fixed_iters: Optional[bool] = None, store_debug: bool = False, sync_cuda_time: bool = True, collision_activation_distance: Optional[float] = None, trajopt_dt: Optional[float] = None, trim_steps: Optional[List[int]] = None, store_debug_in_result: bool = False, smooth_weight: Optional[List[float]] = None, state_finite_difference_mode: Optional[str] = None, filter_robot_command: bool = False, optimize_dt: bool = True, project_pose_to_goal_frame: bool = True, ): # NOTE: Don't have default optimize_dt, instead read from a configuration file. 
# use default values, disable environment collision checking if isinstance(robot_cfg, str): robot_cfg = load_yaml(join_path(get_robot_configs_path(), robot_cfg))["robot_cfg"] if isinstance(robot_cfg, dict): robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args) base_config_data = load_yaml(join_path(get_task_configs_path(), base_cfg_file)) if collision_cache is not None: base_config_data["world_collision_checker_cfg"]["cache"] = collision_cache if n_collision_envs is not None: base_config_data["world_collision_checker_cfg"]["n_envs"] = n_collision_envs if not self_collision_check: base_config_data["constraint"]["self_collision_cfg"]["weight"] = 0.0 self_collision_opt = False if not minimize_jerk: filter_robot_command = False if collision_checker_type is not None: base_config_data["world_collision_checker_cfg"]["checker_type"] = collision_checker_type if world_coll_checker is None and world_model is not None: world_cfg = WorldCollisionConfig.load_from_dict( base_config_data["world_collision_checker_cfg"], world_model, tensor_args ) world_coll_checker = create_collision_checker(world_cfg) config_data = load_yaml(join_path(get_task_configs_path(), particle_file)) grad_config_data = load_yaml(join_path(get_task_configs_path(), gradient_file)) if traj_tsteps is None: traj_tsteps = grad_config_data["model"]["horizon"] base_config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame base_config_data["convergence"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame grad_config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame base_config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame base_config_data["convergence"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame grad_config_data["cost"]["pose_cfg"]["project_distance"] 
= project_pose_to_goal_frame config_data["model"]["horizon"] = traj_tsteps grad_config_data["model"]["horizon"] = traj_tsteps if minimize_jerk is not None: if not minimize_jerk: grad_config_data["cost"]["bound_cfg"]["smooth_weight"][2] = 0.0 grad_config_data["cost"]["bound_cfg"]["smooth_weight"][1] *= 2.0 grad_config_data["lbfgs"]["cost_delta_threshold"] = 0.1 if minimize_jerk and grad_config_data["cost"]["bound_cfg"]["smooth_weight"][2] == 0.0: log_warn("minimize_jerk flag is enabled but weight term is zero") if state_finite_difference_mode is not None: config_data["model"]["state_finite_difference_mode"] = state_finite_difference_mode grad_config_data["model"]["state_finite_difference_mode"] = state_finite_difference_mode config_data["model"]["filter_robot_command"] = filter_robot_command grad_config_data["model"]["filter_robot_command"] = filter_robot_command if collision_activation_distance is not None: config_data["cost"]["primitive_collision_cfg"][ "activation_distance" ] = collision_activation_distance grad_config_data["cost"]["primitive_collision_cfg"][ "activation_distance" ] = collision_activation_distance if grad_trajopt_iters is not None: grad_config_data["lbfgs"]["n_iters"] = grad_trajopt_iters grad_config_data["lbfgs"]["cold_start_n_iters"] = grad_trajopt_iters if use_fixed_samples is not None: config_data["mppi"]["sample_params"]["fixed_samples"] = use_fixed_samples if smooth_weight is not None: grad_config_data["cost"]["bound_cfg"]["smooth_weight"][ : len(smooth_weight) ] = smooth_weight # velocity if store_debug: use_cuda_graph = False fixed_iters = True grad_config_data["lbfgs"]["store_debug"] = store_debug config_data["mppi"]["store_debug"] = store_debug store_debug_in_result = True if use_cuda_graph is not None: config_data["mppi"]["use_cuda_graph"] = use_cuda_graph grad_config_data["lbfgs"]["use_cuda_graph"] = use_cuda_graph else: use_cuda_graph = grad_config_data["lbfgs"]["use_cuda_graph"] if not self_collision_opt: 
config_data["cost"]["self_collision_cfg"]["weight"] = 0.0 grad_config_data["cost"]["self_collision_cfg"]["weight"] = 0.0 config_data["mppi"]["n_problems"] = 1 grad_config_data["lbfgs"]["n_problems"] = 1 if fixed_iters is not None: grad_config_data["lbfgs"]["fixed_iters"] = fixed_iters grad_cfg = ArmReacherConfig.from_dict( robot_cfg, grad_config_data["model"], grad_config_data["cost"], base_config_data["constraint"], base_config_data["convergence"], base_config_data["world_collision_checker_cfg"], world_model, world_coll_checker=world_coll_checker, tensor_args=tensor_args, ) cfg = ArmReacherConfig.from_dict( robot_cfg, config_data["model"], config_data["cost"], base_config_data["constraint"], base_config_data["convergence"], base_config_data["world_collision_checker_cfg"], world_model, world_coll_checker=world_coll_checker, tensor_args=tensor_args, ) safety_robot_model = robot_cfg.kinematics safety_robot_cfg = RobotConfig(**vars(robot_cfg)) safety_robot_cfg.kinematics = safety_robot_model safety_cfg = ArmReacherConfig.from_dict( safety_robot_cfg, config_data["model"], config_data["cost"], base_config_data["constraint"], base_config_data["convergence"], base_config_data["world_collision_checker_cfg"], world_model, world_coll_checker=world_coll_checker, tensor_args=tensor_args, ) arm_rollout_mppi = None with profiler.record_function("trajopt_config/create_rollouts"): if use_particle_opt: arm_rollout_mppi = ArmReacher(cfg) arm_rollout_grad = ArmReacher(grad_cfg) arm_rollout_safety = ArmReacher(safety_cfg) if aux_rollout is None: aux_rollout = ArmReacher(safety_cfg) interpolate_rollout = ArmReacher(safety_cfg) if trajopt_dt is not None: if arm_rollout_mppi is not None: arm_rollout_mppi.update_traj_dt(dt=trajopt_dt) aux_rollout.update_traj_dt(dt=trajopt_dt) arm_rollout_grad.update_traj_dt(dt=trajopt_dt) arm_rollout_safety.update_traj_dt(dt=trajopt_dt) if arm_rollout_mppi is not None: config_dict = ParallelMPPIConfig.create_data_dict( config_data["mppi"], 
arm_rollout_mppi, tensor_args ) parallel_mppi = None if use_es is not None and use_es: mppi_cfg = ParallelESConfig(**config_dict) if es_learning_rate is not None: mppi_cfg.learning_rate = es_learning_rate parallel_mppi = ParallelES(mppi_cfg) elif use_particle_opt: mppi_cfg = ParallelMPPIConfig(**config_dict) parallel_mppi = ParallelMPPI(mppi_cfg) config_dict = LBFGSOptConfig.create_data_dict( grad_config_data["lbfgs"], arm_rollout_grad, tensor_args ) lbfgs_cfg = LBFGSOptConfig(**config_dict) if use_gradient_descent: newton_keys = NewtonOptConfig.__dataclass_fields__.keys() newton_d = {} lbfgs_k = vars(lbfgs_cfg) for k in newton_keys: newton_d[k] = lbfgs_k[k] newton_d["step_scale"] = 0.9 newton_cfg = NewtonOptConfig(**newton_d) lbfgs = NewtonOptBase(newton_cfg) else: lbfgs = LBFGSOpt(lbfgs_cfg) if use_particle_opt: opt_list = [parallel_mppi] else: opt_list = [] opt_list.append(lbfgs) cfg = WrapConfig( safety_rollout=arm_rollout_safety, optimizers=opt_list, compute_metrics=True, # not evaluate_interpolated_trajectory, use_cuda_graph_metrics=grad_config_data["lbfgs"]["use_cuda_graph"], sync_cuda_time=sync_cuda_time, ) trajopt = WrapBase(cfg) trajopt_cfg = TrajOptSolverConfig( robot_config=robot_cfg, rollout_fn=aux_rollout, solver=trajopt, position_threshold=position_threshold, rotation_threshold=rotation_threshold, cspace_threshold=cspace_threshold, traj_tsteps=traj_tsteps, interpolation_steps=interpolation_steps, interpolation_dt=interpolation_dt, interpolation_type=interpolation_type, world_coll_checker=world_coll_checker, bias_node=robot_cfg.kinematics.cspace.retract_config, seed_ratio=seed_ratio, num_seeds=num_seeds, traj_evaluator_config=traj_evaluator_config, traj_evaluator=traj_evaluator, evaluate_interpolated_trajectory=evaluate_interpolated_trajectory, tensor_args=tensor_args, sync_cuda_time=sync_cuda_time, interpolate_rollout=interpolate_rollout, use_cuda_graph_metrics=use_cuda_graph, trim_steps=trim_steps, store_debug_in_result=store_debug_in_result, 
            optimize_dt=optimize_dt,
            use_cuda_graph=use_cuda_graph,
        )
        return trajopt_cfg


@dataclass
class TrajResult(Sequence):
    """Result container for one trajectory-optimization call.

    Behaves as a sequence over the batch dimension: indexing returns a
    TrajResult holding the selected problem(s).
    """

    success: T_BValue_bool
    goal: Goal
    solution: JointState
    seed: T_BDOF
    solve_time: float
    debug_info: Optional[Any] = None
    metrics: Optional[RolloutMetrics] = None
    interpolated_solution: Optional[JointState] = None
    path_buffer_last_tstep: Optional[List[int]] = None
    position_error: Optional[T_BValue_float] = None
    rotation_error: Optional[T_BValue_float] = None
    cspace_error: Optional[T_BValue_float] = None
    smooth_error: Optional[T_BValue_float] = None
    smooth_label: Optional[T_BValue_bool] = None
    optimized_dt: Optional[torch.Tensor] = None
    raw_solution: Optional[JointState] = None
    raw_action: Optional[torch.Tensor] = None
    goalset_index: Optional[torch.Tensor] = None

    def __getitem__(self, idx):
        """Index into the batch dimension; None fields stay None."""
        # position_error = rotation_error = cspace_error = path_buffer_last_tstep = None
        # metrics = interpolated_solution = None
        d_list = [
            self.interpolated_solution,
            self.metrics,
            self.path_buffer_last_tstep,
            self.position_error,
            self.rotation_error,
            self.cspace_error,
            self.goalset_index,
        ]
        # list_idx_if_not_none indexes each entry only when it is not None
        idx_vals = list_idx_if_not_none(d_list, idx)
        return TrajResult(
            goal=self.goal[idx],
            solution=self.solution[idx],
            success=self.success[idx],
            seed=self.seed[idx],
            debug_info=self.debug_info,
            solve_time=self.solve_time,
            interpolated_solution=idx_vals[0],
            metrics=idx_vals[1],
            path_buffer_last_tstep=idx_vals[2],
            position_error=idx_vals[3],
            rotation_error=idx_vals[4],
            cspace_error=idx_vals[5],
            goalset_index=idx_vals[6],
        )

    def __len__(self):
        """Batch size, taken from the success tensor's first dimension."""
        return self.success.shape[0]


class TrajOptSolver(TrajOptSolverConfig):
    """GPU trajectory optimizer; solves reacher problems from a config."""

    def __init__(self, config: TrajOptSolverConfig) -> None:
        """Initialize solver state, seed interpolators, and joint limits."""
        super().__init__(**vars(config))
        self.dof = self.rollout_fn.d_action
        self.action_horizon = self.rollout_fn.action_horizon
        # interpolation matrices used to build linear / waypoint seeds
        self.delta_vec = interpolate_kernel(2, self.action_horizon, self.tensor_args).unsqueeze(0)
        self.waypoint_delta_vec = interpolate_kernel(
            3, int(self.action_horizon / 2), self.tensor_args
        ).unsqueeze(0)
        assert
self.action_horizon / 2 != 0.0
        self.solver.update_nproblems(self.num_seeds)
        # joint limits are tightened by a small margin (0.02) for safety checks
        self._max_joint_vel = (
            self.solver.safety_rollout.state_bounds.velocity.view(2, self.dof)[1, :].reshape(
                1, 1, self.dof
            )
        ) - 0.02
        self._max_joint_acc = self.rollout_fn.state_bounds.acceleration[1, :] - 0.02
        self._max_joint_jerk = self.rollout_fn.state_bounds.jerk[1, :] - 0.02
        self._num_seeds = -1
        self._col = None
        if self.traj_evaluator is None:
            self.traj_evaluator = TrajEvaluator(self.traj_evaluator_config)
        self._interpolation_dt_tensor = self.tensor_args.to_device([self.interpolation_dt])
        self._n_seeds = self._get_seed_numbers(self.num_seeds)
        self._goal_buffer = None
        self._solve_state = None
        self._velocity_bounds = self.solver.rollout_fn.state_bounds.velocity[1]
        # remember the gradient optimizer's original iteration settings so
        # temporary newton_iters overrides can be undone
        self._og_newton_iters = self.solver.optimizers[-1].outer_iters
        self._og_newton_fixed_iters = self.solver.optimizers[-1].fixed_iters
        self._interpolated_traj_buffer = None
        self._kin_list = None
        self._rollout_list = None

    def get_all_rollout_instances(self) -> List[RolloutBase]:
        """Return every rollout instance owned by this solver (cached)."""
        if self._rollout_list is None:
            self._rollout_list = [
                self.rollout_fn,
                self.interpolate_rollout,
            ] + self.solver.get_all_rollout_instances()
        return self._rollout_list

    def get_all_kinematics_instances(self) -> List[CudaRobotModel]:
        """Return the robot model of every rollout instance (cached)."""
        if self._kin_list is None:
            self._kin_list = [
                i.dynamics_model.robot_model for i in self.get_all_rollout_instances()
            ]
        return self._kin_list

    def attach_object_to_robot(
        self,
        sphere_radius: float,
        sphere_tensor: Optional[torch.Tensor] = None,
        link_name: str = "attached_object",
    ) -> None:
        """Attach a collision sphere to `link_name` on every kinematics model."""
        for k in self.get_all_kinematics_instances():
            k.attach_object(
                sphere_radius=sphere_radius, sphere_tensor=sphere_tensor, link_name=link_name
            )

    def detach_object_from_robot(self, link_name: str = "attached_object") -> None:
        """Detach a previously attached object from every kinematics model."""
        for k in self.get_all_kinematics_instances():
            k.detach_object(link_name)

    def update_goal_buffer(
        self,
        solve_state: ReacherSolveState,
        goal: Goal,
    ):
        """Update the cached goal buffer; resize solver state when the
        problem layout (batch/goalset shape) changes."""
        self._solve_state, self._goal_buffer, update_reference =
solve_state.update_goal(
            goal,
            self._solve_state,
            self._goal_buffer,
            self.tensor_args,
        )
        if update_reference:
            # a layout change invalidates any captured cuda graph
            if self.use_cuda_graph and self._col is not None:
                log_error("changing goal type, breaking previous cuda graph.")
                self.reset_cuda_graph()
            self.solver.update_nproblems(self._solve_state.get_batch_size())
            self._col = torch.arange(
                0, goal.batch, device=self.tensor_args.device, dtype=torch.long
            )
            self.reset_shape()
        return self._goal_buffer

    def solve_any(
        self,
        solve_type: ReacherSolveType,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
        num_seeds: Optional[int] = None,
        seed_success: Optional[torch.Tensor] = None,
        newton_iters: Optional[int] = None,
    ) -> TrajResult:
        """Dispatch to the solve_* method matching `solve_type`."""
        if solve_type == ReacherSolveType.SINGLE:
            return self.solve_single(
                goal,
                seed_traj,
                use_nn_seed,
                return_all_solutions,
                num_seeds,
                newton_iters=newton_iters,
            )
        elif solve_type == ReacherSolveType.GOALSET:
            return self.solve_goalset(
                goal,
                seed_traj,
                use_nn_seed,
                return_all_solutions,
                num_seeds,
                newton_iters=newton_iters,
            )
        elif solve_type == ReacherSolveType.BATCH:
            return self.solve_batch(
                goal,
                seed_traj,
                use_nn_seed,
                return_all_solutions,
                num_seeds,
                seed_success,
                newton_iters=newton_iters,
            )
        elif solve_type == ReacherSolveType.BATCH_GOALSET:
            return self.solve_batch_goalset(
                goal,
                seed_traj,
                use_nn_seed,
                return_all_solutions,
                num_seeds,
                seed_success,
                newton_iters=newton_iters,
            )
        elif solve_type == ReacherSolveType.BATCH_ENV:
            return self.solve_batch_env(
                goal,
                seed_traj,
                use_nn_seed,
                return_all_solutions,
                num_seeds,
                seed_success,
                newton_iters=newton_iters,
            )
        elif solve_type == ReacherSolveType.BATCH_ENV_GOALSET:
            return self.solve_batch_env_goalset(
                goal,
                seed_traj,
                use_nn_seed,
                return_all_solutions,
                num_seeds,
                seed_success,
                newton_iters=newton_iters,
            )

    def solve_from_solve_state(
        self,
        solve_state: ReacherSolveState,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
num_seeds: Optional[int] = None,
        seed_success: Optional[torch.Tensor] = None,
        newton_iters: Optional[int] = None,
    ):
        """Run the optimizer stack for an already-built solve state.

        Optionally overrides the gradient optimizer's iteration count for
        this call only; the original settings are restored before returning.
        """
        if solve_state.batch_env:
            if solve_state.batch_size > self.world_coll_checker.n_envs:
                raise ValueError("Batch Env is less that goal batch")
        if newton_iters is not None:
            # temporary override; undone at the end of this method
            self.solver.newton_optimizer.outer_iters = newton_iters
            self.solver.newton_optimizer.fixed_iters = True
        goal_buffer = self.update_goal_buffer(solve_state, goal)
        # if self.evaluate_interpolated_trajectory:
        self.interpolate_rollout.update_params(goal_buffer)
        # get seeds:
        seed_traj = self.get_seed_set(
            goal_buffer, seed_traj, seed_success, num_seeds, solve_state.batch_mode
        )
        # remove goal state if goal pose is not None:
        if goal_buffer.goal_pose.position is not None:
            goal_buffer.goal_state = None
        self.solver.reset()
        result = self.solver.solve(goal_buffer, seed_traj)
        log_info("Ran TO")
        traj_result = self._get_result(
            result,
            return_all_solutions,
            goal_buffer,
            seed_traj,
            num_seeds,
            solve_state.batch_mode,
        )
        if traj_result.goalset_index is not None:
            # clamp indices that point past the goalset of the original goal
            traj_result.goalset_index[traj_result.goalset_index >= goal.goal_pose.n_goalset] = 0
        if newton_iters is not None:
            self.solver.newton_optimizer.outer_iters = self._og_newton_iters
            self.solver.newton_optimizer.fixed_iters = self._og_newton_fixed_iters
        return traj_result

    def solve_single(
        self,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
        num_seeds: Optional[int] = None,
        newton_iters: Optional[int] = None,
    ) -> TrajResult:
        """Solve one problem with a single goal pose."""
        if num_seeds is None:
            num_seeds = self.num_seeds
        solve_state = ReacherSolveState(
            ReacherSolveType.SINGLE,
            num_trajopt_seeds=num_seeds,
            batch_size=1,
            n_envs=1,
            n_goalset=1,
        )
        return self.solve_from_solve_state(
            solve_state,
            goal,
            seed_traj,
            use_nn_seed,
            return_all_solutions,
            num_seeds,
            newton_iters=newton_iters,
        )

    def solve_goalset(
        self,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
        num_seeds: Optional[int] = None,
        newton_iters: Optional[int] = None,
    ) -> TrajResult:
        """Solve one problem where any pose in the goalset is acceptable."""
        if num_seeds is None:
            num_seeds = self.num_seeds
        solve_state = ReacherSolveState(
            ReacherSolveType.GOALSET,
            num_trajopt_seeds=num_seeds,
            batch_size=1,
            n_envs=1,
            n_goalset=goal.n_goalset,
        )
        return self.solve_from_solve_state(
            solve_state,
            goal,
            seed_traj,
            use_nn_seed,
            return_all_solutions,
            num_seeds,
            newton_iters=newton_iters,
        )

    def solve_batch(
        self,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
        num_seeds: Optional[int] = None,
        seed_success: Optional[torch.Tensor] = None,
        newton_iters: Optional[int] = None,
    ) -> TrajResult:
        """Solve a batch of problems in a single environment."""
        if num_seeds is None:
            num_seeds = self.num_seeds
        solve_state = ReacherSolveState(
            ReacherSolveType.BATCH,
            num_trajopt_seeds=num_seeds,
            batch_size=goal.batch,
            n_envs=1,
            n_goalset=1,
        )
        return self.solve_from_solve_state(
            solve_state,
            goal,
            seed_traj,
            use_nn_seed,
            return_all_solutions,
            num_seeds,
            seed_success,
            newton_iters=newton_iters,
        )

    def solve_batch_goalset(
        self,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
        num_seeds: Optional[int] = None,
        seed_success: Optional[torch.Tensor] = None,
        newton_iters: Optional[int] = None,
    ) -> TrajResult:
        """Solve a batch of goalset problems in a single environment."""
        if num_seeds is None:
            num_seeds = self.num_seeds
        solve_state = ReacherSolveState(
            ReacherSolveType.BATCH_GOALSET,
            num_trajopt_seeds=num_seeds,
            batch_size=goal.batch,
            n_envs=1,
            n_goalset=goal.n_goalset,
        )
        # NOTE(review): seed_success is accepted but not forwarded here
        # (unlike solve_batch) — confirm whether this is intentional.
        return self.solve_from_solve_state(
            solve_state,
            goal,
            seed_traj,
            use_nn_seed,
            return_all_solutions,
            num_seeds,
            newton_iters=newton_iters,
        )

    def solve_batch_env(
        self,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
        num_seeds: Optional[int] = None,
        seed_success: Optional[torch.Tensor] = None,
        newton_iters: Optional[int] = None,
    ) -> TrajResult:
        """Solve a batch of problems, one per collision environment."""
        if num_seeds is None:
            num_seeds = self.num_seeds
        solve_state = ReacherSolveState(
            ReacherSolveType.BATCH_ENV,
            num_trajopt_seeds=num_seeds,
            batch_size=goal.batch,
            n_envs=goal.batch,
            n_goalset=1,
        )
        return self.solve_from_solve_state(
            solve_state,
            goal,
            seed_traj,
            use_nn_seed,
            return_all_solutions,
            num_seeds,
            seed_success,
            newton_iters=newton_iters,
        )

    def solve_batch_env_goalset(
        self,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
        num_seeds: Optional[int] = None,
        seed_success: Optional[torch.Tensor] = None,
        newton_iters: Optional[int] = None,
    ) -> TrajResult:
        """Solve a batch of goalset problems, one per collision environment."""
        if num_seeds is None:
            num_seeds = self.num_seeds
        solve_state = ReacherSolveState(
            ReacherSolveType.BATCH_ENV_GOALSET,
            num_trajopt_seeds=num_seeds,
            batch_size=goal.batch,
            n_envs=goal.batch,
            n_goalset=goal.n_goalset,
        )
        # NOTE(review): seed_success is accepted but not forwarded here —
        # confirm whether this is intentional.
        return self.solve_from_solve_state(
            solve_state,
            goal,
            seed_traj,
            use_nn_seed,
            return_all_solutions,
            num_seeds,
            newton_iters=newton_iters,
        )

    def solve(
        self,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
        num_seeds: Optional[int] = None,
        newton_iters: Optional[int] = None,
    ) -> TrajResult:
        """Only for single goal

        Args:
            goal (Goal): _description_
            seed_traj (Optional[JointState], optional): _description_. Defaults to None.
            use_nn_seed (bool, optional): _description_. Defaults to False.
        Raises:
            NotImplementedError: _description_

        Returns:
            TrajResult: _description_
        """
        log_warn("TrajOpt.solve() is deprecated, use TrajOpt.solve_single or others instead")
        if goal.goal_pose.batch == 1 and goal.goal_pose.n_goalset == 1:
            return self.solve_single(
                goal,
                seed_traj,
                use_nn_seed,
                return_all_solutions,
                num_seeds,
                newton_iters=newton_iters,
            )
        if goal.goal_pose.batch == 1 and goal.goal_pose.n_goalset > 1:
            return self.solve_goalset(
                goal,
                seed_traj,
                use_nn_seed,
                return_all_solutions,
                num_seeds,
                newton_iters=newton_iters,
            )
        if goal.goal_pose.batch > 1 and goal.goal_pose.n_goalset == 1:
            return self.solve_batch(
                goal,
                seed_traj,
                use_nn_seed,
                return_all_solutions,
                num_seeds,
                newton_iters=newton_iters,
            )
        raise NotImplementedError()

    @profiler.record_function("trajopt/get_result")
    def _get_result(
        self,
        result: WrapResult,
        return_all_solutions: bool,
        goal: Goal,
        seed_traj: JointState,
        num_seeds: int,
        batch_mode: bool = False,
    ):
        """Post-process an optimizer result: interpolate, check feasibility
        and convergence, and (unless return_all_solutions) pick the best
        seed per problem."""
        st_time = time.time()
        if self.trim_steps is not None:
            result.action = result.action.trim_trajectory(self.trim_steps[0], self.trim_steps[1])
        interpolated_trajs, last_tstep, opt_dt = self.get_interpolated_trajectory(result.action)
        if self.sync_cuda_time:
            torch.cuda.synchronize()
        interpolation_time = time.time() - st_time
        if self.evaluate_interpolated_trajectory:
            # re-evaluate metrics on the interpolated (dense) trajectory
            with profiler.record_function("trajopt/evaluate_interpolated"):
                if self.use_cuda_graph_metrics:
                    metrics = self.interpolate_rollout.get_metrics_cuda_graph(interpolated_trajs)
                else:
                    metrics = self.interpolate_rollout.get_metrics(interpolated_trajs)
                result.metrics.feasible = metrics.feasible
                result.metrics.position_error = metrics.position_error
                result.metrics.rotation_error = metrics.rotation_error
                result.metrics.cspace_error = metrics.cspace_error
                result.metrics.goalset_index = metrics.goalset_index

        st_time = time.time()
        if result.metrics.cspace_error is None and result.metrics.position_error is None:
            # NOTE(review): log_error presumably raises, so the `raise` here
            # never sees its (None) operand — confirm against curobo.util.logger.
            raise log_error("convergence check requires either goal_pose or goal_state")

        success = jit_feasible_success(
            result.metrics.feasible,
            result.metrics.position_error,
            result.metrics.rotation_error,
            result.metrics.cspace_error,
            self.position_threshold,
            self.rotation_threshold,
            self.cspace_threshold,
        )
        # legacy (non-jit) success computation, kept disabled for reference
        if False:
            feasible = torch.all(result.metrics.feasible, dim=-1)

            if result.metrics.position_error is not None:
                converge = torch.logical_and(
                    result.metrics.position_error[..., -1] <= self.position_threshold,
                    result.metrics.rotation_error[..., -1] <= self.rotation_threshold,
                )
            elif result.metrics.cspace_error is not None:
                converge = result.metrics.cspace_error[..., -1] <= self.cspace_threshold
            else:
                raise ValueError("convergence check requires either goal_pose or goal_state")
            success = torch.logical_and(feasible, converge)
        if return_all_solutions:
            traj_result = TrajResult(
                success=success,
                goal=goal,
                solution=result.action.scale_by_dt(self.solver_dt_tensor, opt_dt.view(-1, 1, 1)),
                seed=seed_traj,
                position_error=result.metrics.position_error,
                rotation_error=result.metrics.rotation_error,
                solve_time=result.solve_time,
                metrics=result.metrics,  # TODO: index this also
                interpolated_solution=interpolated_trajs,
                debug_info={"solver": result.debug, "interpolation_time": interpolation_time},
                path_buffer_last_tstep=last_tstep,
                cspace_error=result.metrics.cspace_error,
                optimized_dt=opt_dt,
                raw_solution=result.action,
                raw_action=result.raw_action,
                goalset_index=result.metrics.goalset_index,
            )
        else:
            # get path length:
            if self.evaluate_interpolated_trajectory:
                smooth_label, smooth_cost = self.traj_evaluator.evaluate_interpolated_smootheness(
                    interpolated_trajs,
                    opt_dt,
                    self.solver.rollout_fn.dynamics_model.cspace_distance_weight,
                    self._velocity_bounds,
                )
            else:
                smooth_label, smooth_cost = self.traj_evaluator.evaluate(
                    result.action,
                    self.solver.rollout_fn.traj_dt,
                    self.solver.rollout_fn.dynamics_model.cspace_distance_weight,
                    self._velocity_bounds,
                )

            with profiler.record_function("trajopt/best_select"):
                if True:  # not get_torch_jit_decorator() == torch.jit.script:
                    # This only works if torch compile is available:
                    (
                        idx,
                        position_error,
                        rotation_error,
                        cspace_error,
                        goalset_index,
                        opt_dt,
                        success,
                    ) = jit_trajopt_best_select(
                        success,
                        smooth_label,
                        result.metrics.cspace_error,
                        result.metrics.pose_error,
                        result.metrics.position_error,
                        result.metrics.rotation_error,
                        result.metrics.goalset_index,
                        result.metrics.cost,
                        smooth_cost,
                        batch_mode,
                        goal.batch,
                        num_seeds,
                        self._col,
                        opt_dt,
                    )
                    if batch_mode:
                        last_tstep = [last_tstep[i] for i in idx]
                    else:
                        last_tstep = [last_tstep[idx.item()]]
                    best_act_seq = result.action[idx]
                    best_raw_action = result.raw_action[idx]
                    interpolated_traj = interpolated_trajs[idx]
                else:
                    # legacy (non-jit) selection path, currently unreachable
                    success[~smooth_label] = False
                    # get the best solution:
                    if result.metrics.pose_error is not None:
                        convergence_error = result.metrics.pose_error[..., -1]
                    elif result.metrics.cspace_error is not None:
                        convergence_error = result.metrics.cspace_error[..., -1]
                    else:
                        raise ValueError(
                            "convergence check requires either goal_pose or goal_state"
                        )

                    running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
                    error = convergence_error + smooth_cost + running_cost
                    # push failed seeds out of the argmin selection
                    error[~success] += 10000.0
                    if batch_mode:
                        idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
                        idx = idx + num_seeds * self._col
                        last_tstep = [last_tstep[i] for i in idx]
                        success = success[idx]
                    else:
                        idx = torch.argmin(error, dim=0)
                        last_tstep = [last_tstep[idx.item()]]
                        success = success[idx : idx + 1]

                    best_act_seq = result.action[idx]
                    best_raw_action = result.raw_action[idx]
                    interpolated_traj = interpolated_trajs[idx]
                    goalset_index = position_error = rotation_error = cspace_error = None
                    if result.metrics.position_error is not None:
                        position_error = result.metrics.position_error[idx, -1]
                    if result.metrics.rotation_error is not None:
                        rotation_error = result.metrics.rotation_error[idx, -1]
                    if result.metrics.cspace_error is not None:
                        cspace_error = result.metrics.cspace_error[idx, -1]
                    if result.metrics.goalset_index is not None:
                        goalset_index = result.metrics.goalset_index[idx, -1]
                    opt_dt = opt_dt[idx]

            if self.sync_cuda_time:
                torch.cuda.synchronize()
            if len(best_act_seq.shape) == 3:
                opt_dt_v = opt_dt.view(-1, 1, 1)
            else:
                opt_dt_v = opt_dt.view(1, 1)
            opt_solution = best_act_seq.scale_by_dt(self.solver_dt_tensor, opt_dt_v)
            select_time = time.time() - st_time
            debug_info = None
            if self.store_debug_in_result:
                debug_info = {
                    "solver": result.debug,
                    "interpolation_time": interpolation_time,
                    "select_time": select_time,
                }
            traj_result = TrajResult(
                success=success,
                goal=goal,
                solution=opt_solution,
                seed=seed_traj,
                position_error=position_error,
                rotation_error=rotation_error,
                cspace_error=cspace_error,
                solve_time=result.solve_time,
                metrics=result.metrics,  # TODO: index this also
                interpolated_solution=interpolated_traj,
                debug_info=debug_info,
                path_buffer_last_tstep=last_tstep,
                smooth_error=smooth_cost,
                smooth_label=smooth_label,
                optimized_dt=opt_dt,
                raw_solution=best_act_seq,
                raw_action=best_raw_action,
                goalset_index=goalset_index,
            )
        return traj_result

    def batch_solve(
        self,
        goal: Goal,
        seed_traj: Optional[JointState] = None,
        seed_success: Optional[torch.Tensor] = None,
        use_nn_seed: bool = False,
        return_all_solutions: bool = False,
        num_seeds: Optional[int] = None,
    ) -> TrajResult:
        """Only for single goal

        Args:
            goal (Goal): _description_
            seed_traj (Optional[JointState], optional): _description_. Defaults to None.
            use_nn_seed (bool, optional): _description_. Defaults to False.
Raises: NotImplementedError: _description_ Returns: TrajResult: _description_ """ log_warn("TrajOpt.solve() is deprecated, use TrajOpt.solve_single or others instead") if goal.n_goalset == 1: return self.solve_batch( goal, seed_traj, use_nn_seed, return_all_solutions, num_seeds, seed_success ) if goal.n_goalset > 1: return self.solve_batch_goalset( goal, seed_traj, use_nn_seed, return_all_solutions, num_seeds, seed_success ) def get_linear_seed(self, start_state, goal_state): start_q = start_state.position.view(-1, 1, self.dof) end_q = goal_state.position.view(-1, 1, self.dof) edges = torch.cat((start_q, end_q), dim=1) seed = self.delta_vec @ edges return seed def get_start_seed(self, start_state): start_q = start_state.position.view(-1, 1, self.dof) edges = torch.cat((start_q, start_q), dim=1) seed = self.delta_vec @ edges return seed def _get_seed_numbers(self, num_seeds): n_seeds = {"linear": 0, "bias": 0, "start": 0, "goal": 0} k = n_seeds.keys t_seeds = 0 for k in n_seeds: if k not in self.seed_ratio: continue if self.seed_ratio[k] > 0.0: n_seeds[k] = math.floor(self.seed_ratio[k] * num_seeds) t_seeds += n_seeds[k] if t_seeds < num_seeds: n_seeds["linear"] += num_seeds - t_seeds return n_seeds def get_seed_set( self, goal: Goal, seed_traj: Union[JointState, torch.Tensor, None] = None, # n_seeds,batch, h, dof seed_success: Optional[torch.Tensor] = None, # batch, n_seeds num_seeds: Optional[int] = None, batch_mode: bool = False, ): # if batch_mode: total_seeds = goal.batch * num_seeds # else: # total_seeds = num_seeds if isinstance(seed_traj, JointState): seed_traj = seed_traj.position if seed_traj is None: if goal.goal_state is not None and self.use_cspace_seed: # get linear seed seed_traj = self.get_seeds(goal.current_state, goal.goal_state, num_seeds=num_seeds) # .view(batch_size, self.num_seeds, self.action_horizon, self.dof) else: # get start repeat seed: log_info("No goal state found, using current config to seed") seed_traj = self.get_seeds( 
goal.current_state, goal.current_state, num_seeds=num_seeds ) elif seed_success is not None: lin_seed_traj = self.get_seeds(goal.current_state, goal.goal_state, num_seeds=num_seeds) lin_seed_traj[seed_success] = seed_traj # [seed_success] seed_traj = lin_seed_traj total_seeds = goal.batch * num_seeds elif num_seeds > seed_traj.shape[0]: new_seeds = self.get_seeds( goal.current_state, goal.goal_state, num_seeds - seed_traj.shape[0] ) seed_traj = torch.cat((seed_traj, new_seeds), dim=0) # n_seed, batch, h, dof seed_traj = seed_traj.view( total_seeds, self.action_horizon, self.dof ) # n_seeds,batch, h, dof return seed_traj def get_seeds(self, start_state, goal_state, num_seeds=None): # repeat seeds: if num_seeds is None: num_seeds = self.num_seeds n_seeds = self._n_seeds else: n_seeds = self._get_seed_numbers(num_seeds) # linear seed: batch x dof -> batch x n_seeds x dof seed_set = [] if n_seeds["linear"] > 0: linear_seed = self.get_linear_seed(start_state, goal_state) linear_seeds = linear_seed.view(1, -1, self.action_horizon, self.dof).repeat( 1, n_seeds["linear"], 1, 1 ) seed_set.append(linear_seeds) if n_seeds["bias"] > 0: bias_seed = self.get_bias_seed(start_state, goal_state) bias_seeds = bias_seed.view(1, -1, self.action_horizon, self.dof).repeat( 1, n_seeds["bias"], 1, 1 ) seed_set.append(bias_seeds) if n_seeds["start"] > 0: bias_seed = self.get_start_seed(start_state) bias_seeds = bias_seed.view(1, -1, self.action_horizon, self.dof).repeat( 1, n_seeds["start"], 1, 1 ) seed_set.append(bias_seeds) if n_seeds["goal"] > 0: bias_seed = self.get_start_seed(goal_state) bias_seeds = bias_seed.view(1, -1, self.action_horizon, self.dof).repeat( 1, n_seeds["goal"], 1, 1 ) seed_set.append(bias_seeds) all_seeds = torch.cat( seed_set, dim=1 ) # .#transpose(0,1).contiguous() # n_seed, batch, h, dof return all_seeds def get_bias_seed(self, start_state, goal_state): start_q = start_state.position.view(-1, 1, self.dof) end_q = goal_state.position.view(-1, 1, self.dof) bias_q = 
self.bias_node.view(-1, 1, self.dof).repeat(start_q.shape[0], 1, 1) edges = torch.cat((start_q, bias_q, end_q), dim=1) seed = self.waypoint_delta_vec @ edges # print(seed) return seed @profiler.record_function("trajopt/interpolation") def get_interpolated_trajectory(self, traj_state: JointState): # do interpolation: if ( self._interpolated_traj_buffer is None or traj_state.position.shape[0] != self._interpolated_traj_buffer.position.shape[0] ): b, _, dof = traj_state.position.shape self._interpolated_traj_buffer = JointState.zeros( (b, self.interpolation_steps, dof), self.tensor_args ) self._interpolated_traj_buffer.joint_names = self.rollout_fn.joint_names state, last_tstep, opt_dt = get_batch_interpolated_trajectory( traj_state, self.interpolation_dt, self._max_joint_vel, self._max_joint_acc, self._max_joint_jerk, self.solver_dt_tensor, kind=self.interpolation_type, tensor_args=self.tensor_args, out_traj_state=self._interpolated_traj_buffer, min_dt=self.traj_evaluator_config.min_dt, optimize_dt=self.optimize_dt, ) return state, last_tstep, opt_dt def calculate_trajectory_dt( self, trajectory: JointState, ) -> torch.Tensor: opt_dt = calculate_dt_no_clamp( trajectory.velocity, trajectory.acceleration, trajectory.jerk, self._max_joint_vel, self._max_joint_acc, self._max_joint_jerk, ) return opt_dt def reset_seed(self): self.solver.reset_seed() def reset_cuda_graph(self): self.solver.reset_cuda_graph() self.interpolate_rollout.reset_cuda_graph() self.rollout_fn.reset_cuda_graph() def reset_shape(self): self.solver.reset_shape() self.interpolate_rollout.reset_shape() self.rollout_fn.reset_shape() @property def kinematics(self) -> CudaRobotModel: return self.rollout_fn.dynamics_model.robot_model @property def retract_config(self): return self.rollout_fn.dynamics_model.retract_config.view(1, -1) def fk(self, q: torch.Tensor) -> CudaRobotModelState: return self.kinematics.get_state(q) @property def solver_dt(self): return 
self.solver.safety_rollout.dynamics_model.traj_dt[0] # return self.solver.safety_rollout.dynamics_model.dt_traj_params.base_dt @property def solver_dt_tensor(self): return self.solver.safety_rollout.dynamics_model.traj_dt[0] def update_solver_dt( self, dt: Union[float, torch.Tensor], base_dt: Optional[float] = None, max_dt: Optional[float] = None, base_ratio: Optional[float] = None, ): all_rollouts = self.get_all_rollout_instances() for rollout in all_rollouts: rollout.update_traj_dt(dt, base_dt, max_dt, base_ratio) def compute_metrics(self, opt_trajectory: bool, interpolated_trajectory: bool): self.solver.compute_metrics = opt_trajectory self.evaluate_interpolated_trajectory = interpolated_trajectory def get_full_js(self, active_js: JointState) -> JointState: return self.rollout_fn.get_full_dof_from_solution(active_js) def update_pose_cost_metric( self, metric: PoseCostMetric, ): rollouts = self.get_all_rollout_instances() [ rollout.update_pose_cost_metric(metric) for rollout in rollouts if isinstance(rollout, ArmReacher) ] @get_torch_jit_decorator() def jit_feasible_success( feasible, position_error: Union[torch.Tensor, None], rotation_error: Union[torch.Tensor, None], cspace_error: Union[torch.Tensor, None], position_threshold: float, rotation_threshold: float, cspace_threshold: float, ): feasible = torch.all(feasible, dim=-1) converge = feasible if position_error is not None and rotation_error is not None: converge = torch.logical_and( position_error[..., -1] <= position_threshold, rotation_error[..., -1] <= rotation_threshold, ) elif cspace_error is not None: converge = cspace_error[..., -1] <= cspace_threshold success = torch.logical_and(feasible, converge) return success @get_torch_jit_decorator(only_valid_for_compile=True) def jit_trajopt_best_select( success, smooth_label, cspace_error: Union[torch.Tensor, None], pose_error: Union[torch.Tensor, None], position_error: Union[torch.Tensor, None], rotation_error: Union[torch.Tensor, None], goalset_index: 
Union[torch.Tensor, None], cost, smooth_cost, batch_mode: bool, batch: int, num_seeds: int, col, opt_dt, ): success[~smooth_label] = False convergence_error = 0 # get the best solution: if pose_error is not None: convergence_error = pose_error[..., -1] elif cspace_error is not None: convergence_error = cspace_error[..., -1] running_cost = torch.mean(cost, dim=-1) * 0.0001 error = convergence_error + smooth_cost + running_cost error[~success] += 10000.0 if batch_mode: idx = torch.argmin(error.view(batch, num_seeds), dim=-1) idx = idx + num_seeds * col success = success[idx] else: idx = torch.argmin(error, dim=0) success = success[idx : idx + 1] # goalset_index = position_error = rotation_error = cspace_error = None if position_error is not None: position_error = position_error[idx, -1] if rotation_error is not None: rotation_error = rotation_error[idx, -1] if cspace_error is not None: cspace_error = cspace_error[idx, -1] if goalset_index is not None: goalset_index = goalset_index[idx, -1] opt_dt = opt_dt[idx] return idx, position_error, rotation_error, cspace_error, goalset_index, opt_dt, success