Add re-timing, minimum dt robustness
This commit is contained in:
@@ -8,7 +8,31 @@
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
#
|
||||
"""
|
||||
This module contains :meth:`MpcSolver` that provides a high-level interface to for model
|
||||
predictive control (MPC) for reaching Cartesian poses and also joint configurations while
|
||||
avoiding obstacles. The solver uses Model Predictive Path Integral (MPPI) optimization as the
|
||||
solver. MPC only optimizes locally so the robot can get stuck near joint limits or behind
|
||||
obstacles. To generate global trajectories, use
|
||||
:py:meth:`~curobo.wrap.reacher.motion_gen.MotionGen`.
|
||||
|
||||
A python example is available at :ref:`python_mpc_example`.
|
||||
|
||||
|
||||
|
||||
.. note::
|
||||
Gradient-based MPC is also implemented with L-BFGS but is highly experimental and not
|
||||
recommended for real robots.
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<p>
|
||||
<video autoplay="True" loop="True" muted="True" preload="auto" width="100%"><source src="../videos/mpc_clip.mp4" type="video/mp4"></video>
|
||||
</p>
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Standard Library
|
||||
import time
|
||||
@@ -19,6 +43,7 @@ from typing import Dict, Optional, Union
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.cuda_robot_model.cuda_robot_model import CudaRobotModel
|
||||
from curobo.geom.sdf.utils import create_collision_checker
|
||||
from curobo.geom.sdf.world import CollisionCheckerType, WorldCollision, WorldCollisionConfig
|
||||
from curobo.geom.types import WorldConfig
|
||||
@@ -44,22 +69,31 @@ from curobo.wrap.wrap_mpc import WrapConfig, WrapMpc
|
||||
|
||||
@dataclass
|
||||
class MpcSolverConfig:
|
||||
"""Configuration dataclass for MPC."""
|
||||
|
||||
#: MPC Solver.
|
||||
solver: WrapMpc
|
||||
|
||||
#: World Collision Checker.
|
||||
world_coll_checker: Optional[WorldCollision] = None
|
||||
|
||||
#: Numeric precision and device to run computations.
|
||||
tensor_args: TensorDeviceType = TensorDeviceType()
|
||||
|
||||
#: Capture full step in MPC as a single CUDA graph. This is not supported currently.
|
||||
use_cuda_graph_full_step: bool = False
|
||||
|
||||
@staticmethod
|
||||
def load_from_robot_config(
|
||||
robot_cfg: Union[Union[str, dict], RobotConfig],
|
||||
world_cfg: Union[Union[str, dict], WorldConfig],
|
||||
world_model: Union[Union[str, dict], WorldConfig],
|
||||
base_cfg: Optional[dict] = None,
|
||||
tensor_args: TensorDeviceType = TensorDeviceType(),
|
||||
compute_metrics: bool = True,
|
||||
use_cuda_graph: Optional[bool] = None,
|
||||
particle_opt_iters: Optional[int] = None,
|
||||
self_collision_check: bool = True,
|
||||
collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.PRIMITIVE,
|
||||
collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.MESH,
|
||||
use_es: Optional[bool] = None,
|
||||
es_learning_rate: Optional[float] = 0.01,
|
||||
use_cuda_graph_metrics: bool = False,
|
||||
@@ -74,9 +108,56 @@ class MpcSolverConfig:
|
||||
use_lbfgs: bool = False,
|
||||
use_mppi: bool = True,
|
||||
):
|
||||
"""Create an MPC solver configuration from robot and world configuration.
|
||||
|
||||
Args:
|
||||
robot_cfg: Robot configuration. Can be a path to a YAML file or a dictionary or
|
||||
an instance of :class:`~curobo.types.robot.RobotConfig`.
|
||||
world_model: World configuration. Can be a path to a YAML file or a dictionary or
|
||||
an instance of :class:`~curobo.geom.types.WorldConfig`.
|
||||
base_cfg: Base configuration for the solver. This file is used to check constraints
|
||||
and convergence. If None, the default configuration from ``base_cfg.yml`` is used.
|
||||
tensor_args: Numeric precision and device to run computations.
|
||||
compute_metrics: Compute metrics on MPC step.
|
||||
use_cuda_graph: Use CUDA graph for the optimization step.
|
||||
particle_opt_iters: Number of iterations for the particle optimization.
|
||||
self_collision_check: Enable self-collision check during MPC optimization.
|
||||
collision_checker_type: Type of collision checker to use. See :ref:`world_collision`.
|
||||
use_es: Use Evolution Strategies (ES) solver for MPC. Highly experimental.
|
||||
es_learning_rate: Learning rate for ES solver.
|
||||
use_cuda_graph_metrics: Use CUDA graph for computing metrics.
|
||||
store_rollouts: Store rollouts information for debugging. This will also store the
|
||||
trajectory taken by the end-effector across the horizon.
|
||||
use_cuda_graph_full_step: Capture full step in MPC as a single CUDA graph. This is
|
||||
experimental and might not work reliably.
|
||||
sync_cuda_time: Synchronize CUDA device with host using
|
||||
:py:func:`torch.cuda.synchronize` before calculating compute time.
|
||||
collision_cache: Cache of obstacles to create to load obstacles between planning calls.
|
||||
An example: ``{"obb": 10, "mesh": 10}``, to create a cache of 10 cuboids and 10
|
||||
meshes.
|
||||
n_collision_envs: Number of collision environments to create for batched planning
|
||||
across different environments. Only used for :py:meth:`MpcSolver.setup_solve_batch_env`
|
||||
and :py:meth:`MpcSolver.setup_solve_batch_env_goalset`.
|
||||
collision_activation_distance: Distance in meters to activate collision cost. A good
|
||||
value to start with is 0.01 meters. Increase the distance if the robot needs to
|
||||
stay further away from obstacles.
|
||||
world_coll_checker: Instance of world collision checker to use for MPC. Leaving this to
|
||||
None will create a new instance of world collision checker using the provided
|
||||
:attr:`world_model`.
|
||||
step_dt: Time step to use between each step in the trajectory. If None, the default
|
||||
time step from the configuration~(`particle_mpc.yml` or `gradient_mpc.yml`)
|
||||
is used. This dt should match the control frequency at which you are sending
|
||||
commands to the robot. This dt should also be greater than than the compute
|
||||
time for a single step.
|
||||
use_lbfgs: Use L-BFGS solver for MPC. Highly experimental.
|
||||
use_mppi: Use MPPI solver for MPC.
|
||||
|
||||
Returns:
|
||||
MpcSolverConfig: Configuration for the MPC solver.
|
||||
"""
|
||||
|
||||
if use_cuda_graph_full_step:
|
||||
log_error("use_cuda_graph_full_step currently is not supported")
|
||||
raise ValueError("use_cuda_graph_full_step currently is not supported")
|
||||
|
||||
task_file = "particle_mpc.yml"
|
||||
config_data = load_yaml(join_path(get_task_configs_path(), task_file))
|
||||
@@ -108,14 +189,14 @@ class MpcSolverConfig:
|
||||
if isinstance(robot_cfg, dict):
|
||||
robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args)
|
||||
|
||||
if isinstance(world_cfg, str):
|
||||
world_cfg = load_yaml(join_path(get_world_configs_path(), world_cfg))
|
||||
if isinstance(world_model, str):
|
||||
world_model = load_yaml(join_path(get_world_configs_path(), world_model))
|
||||
|
||||
if world_coll_checker is None and world_cfg is not None:
|
||||
world_cfg = WorldCollisionConfig.load_from_dict(
|
||||
base_cfg["world_collision_checker_cfg"], world_cfg, tensor_args
|
||||
if world_coll_checker is None and world_model is not None:
|
||||
world_model = WorldCollisionConfig.load_from_dict(
|
||||
base_cfg["world_collision_checker_cfg"], world_model, tensor_args
|
||||
)
|
||||
world_coll_checker = create_collision_checker(world_cfg)
|
||||
world_coll_checker = create_collision_checker(world_model)
|
||||
grad_config_data = None
|
||||
if use_lbfgs:
|
||||
grad_config_data = load_yaml(join_path(get_task_configs_path(), "gradient_mpc.yml"))
|
||||
@@ -134,7 +215,7 @@ class MpcSolverConfig:
|
||||
base_cfg["constraint"],
|
||||
base_cfg["convergence"],
|
||||
base_cfg["world_collision_checker_cfg"],
|
||||
world_cfg,
|
||||
world_model,
|
||||
world_coll_checker=world_coll_checker,
|
||||
tensor_args=tensor_args,
|
||||
)
|
||||
@@ -172,7 +253,7 @@ class MpcSolverConfig:
|
||||
base_cfg["constraint"],
|
||||
base_cfg["convergence"],
|
||||
base_cfg["world_collision_checker_cfg"],
|
||||
world_cfg,
|
||||
world_model,
|
||||
world_coll_checker=world_coll_checker,
|
||||
tensor_args=tensor_args,
|
||||
)
|
||||
@@ -201,13 +282,42 @@ class MpcSolverConfig:
|
||||
|
||||
|
||||
class MpcSolver(MpcSolverConfig):
|
||||
"""Model Predictive Control Solver for Arm Reacher task.
|
||||
"""High-level interface for Model Predictive Control (MPC).
|
||||
|
||||
Args:
|
||||
MpcSolverConfig: _description_
|
||||
MPC can reach Cartesian poses and joint configurations while avoiding obstacles. The solver
|
||||
uses Model Predictive Path Integral (MPPI) optimization as the solver. MPC only optimizes
|
||||
locally so the robot can get stuck near joint limits or behind obstacles. To generate global
|
||||
trajectories, use :py:meth:`~curobo.wrap.reacher.motion_gen.MotionGen`.
|
||||
|
||||
See :ref:`python_mpc_example` for an example. This MPC solver implementation can be used in the
|
||||
following steps:
|
||||
|
||||
1. Create a :py:class:`~curobo.rollout.rollout_base.Goal` object with the target pose or joint
|
||||
configuration.
|
||||
2. Create a goal buffer for the problem type using :meth:`setup_solve_single`,
|
||||
:meth:`setup_solve_goalset`, :meth:`setup_solve_batch`, :meth:`setup_solve_batch_goalset`,
|
||||
:meth:`setup_solve_batch_env`, or :meth:`setup_solve_batch_env_goalset`. Pass the goal
|
||||
object from the previous step to this function. This function will update the internal
|
||||
solve state of MPC and also the goal for MPC. An augmented goal buffer is returned.
|
||||
3. Call :meth:`step` with the current joint state to get the next action.
|
||||
4. To change the goal, create a :py:class:`~curobo.types.math.Pose` object with new pose or
|
||||
:py:class:`~curobo.types.state.JointState` object with new joint configuration. Then
|
||||
copy the target into the augmented goal buffer using
|
||||
``goal_buffer.goal_pose.copy_(new_pose)`` or ``goal_buffer.goal_state.copy_(new_state)``.
|
||||
5. Call :meth:`update_goal` with the augmented goal buffer to update the goal for MPC.
|
||||
6. Call :meth:`step` with the current joint state to get the next action.
|
||||
|
||||
To dynamically change the type of goal reached between pose and joint configuration targets,
|
||||
create the goal object in step 1 with both targets and then use :meth:`enable_cspace_cost` and
|
||||
:meth:`enable_pose_cost` to enable or disable reaching joint configuration cost and pose cost.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MpcSolverConfig) -> None:
|
||||
"""Initializes the MPC solver.
|
||||
|
||||
Args:
|
||||
config: Configuration parameters for MPC.
|
||||
"""
|
||||
super().__init__(**vars(config))
|
||||
self.tensor_args = self.solver.rollout_fn.tensor_args
|
||||
self._goal_buffer = Goal()
|
||||
@@ -222,15 +332,326 @@ class MpcSolver(MpcSolverConfig):
|
||||
self._cu_step_graph = None
|
||||
self._cu_result = None
|
||||
|
||||
def _update_batch_size(self, batch_size):
|
||||
if self.batch_size != batch_size:
|
||||
self.batch_size = batch_size
|
||||
def setup_solve_single(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
"""Creates a goal buffer to solve for a robot to reach target pose or joint configuration.
|
||||
|
||||
def update_goal_buffer(
|
||||
Args:
|
||||
goal: goal object containing target pose or joint configuration.
|
||||
num_seeds: Number of seeds to use in the solver.
|
||||
|
||||
Returns:
|
||||
Goal: Instance of augmented goal buffer.
|
||||
"""
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.SINGLE, num_mpc_seeds=num_seeds, batch_size=1, n_envs=1, n_goalset=1
|
||||
)
|
||||
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
|
||||
|
||||
self.update_goal(goal_buffer)
|
||||
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
"""Creates a goal buffer to solve for a robot to reach a pose in a set of poses.
|
||||
|
||||
Args:
|
||||
goal: goal object containing target goalset or joint configuration.
|
||||
num_seeds: Number of seeds to use in the solver.
|
||||
|
||||
Returns:
|
||||
Goal: Instance of augmented goal buffer.
|
||||
"""
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.GOALSET,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=1,
|
||||
n_envs=1,
|
||||
n_goalset=goal.n_goalset,
|
||||
)
|
||||
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
|
||||
self.update_goal(goal_buffer)
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_batch(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
"""Creates a goal buffer to solve for a batch of robots to reach targets.
|
||||
|
||||
Args:
|
||||
goal: goal object containing target poses or joint configurations.
|
||||
num_seeds: Number of seeds to use in the solver.
|
||||
|
||||
Returns:
|
||||
Goal: Instance of augmented goal buffer.
|
||||
"""
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.BATCH,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=goal.batch,
|
||||
n_envs=1,
|
||||
n_goalset=1,
|
||||
)
|
||||
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
|
||||
self.update_goal(goal_buffer)
|
||||
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_batch_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
"""Creates a goal buffer to solve for a batch of robots to reach a set of poses.
|
||||
|
||||
Args:
|
||||
goal: goal object containing target goalset or joint configurations.
|
||||
num_seeds: Number of seeds to use in the solver.
|
||||
|
||||
Returns:
|
||||
Goal: Instance of augmented goal buffer.
|
||||
"""
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.BATCH_GOALSET,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=goal.batch,
|
||||
n_envs=1,
|
||||
n_goalset=goal.n_goalset,
|
||||
)
|
||||
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
|
||||
self.update_goal(goal_buffer)
|
||||
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_batch_env(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
"""Creates a goal buffer to solve for a batch of robots in different collision worlds.
|
||||
|
||||
Args:
|
||||
goal: goal object containing target poses or joint configurations.
|
||||
num_seeds: Number of seeds to use in the solver.
|
||||
|
||||
Returns:
|
||||
Goal: Instance of augmented goal buffer.
|
||||
"""
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.BATCH_ENV,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=goal.batch,
|
||||
n_envs=1,
|
||||
n_goalset=1,
|
||||
)
|
||||
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
|
||||
self.update_goal(goal_buffer)
|
||||
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_batch_env_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
"""Creates a goal buffer to solve for a batch of robots in different collision worlds.
|
||||
|
||||
Args:
|
||||
goal: goal object containing target goalset or joint configurations.
|
||||
num_seeds: Number of seeds to use in the solver.
|
||||
|
||||
Returns:
|
||||
Goal: Instance of augmented goal buffer.
|
||||
"""
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.BATCH_ENV_GOALSET,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=goal.batch,
|
||||
n_envs=1,
|
||||
n_goalset=goal.n_goalset,
|
||||
)
|
||||
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
|
||||
self.update_goal(goal_buffer)
|
||||
|
||||
return goal_buffer
|
||||
|
||||
def step(
|
||||
self,
|
||||
current_state: JointState,
|
||||
shift_steps: int = 1,
|
||||
seed_traj: Optional[JointState] = None,
|
||||
max_attempts: int = 1,
|
||||
):
|
||||
"""Solve for the next action given the current state.
|
||||
|
||||
Args:
|
||||
current_state: Current joint state of the robot.
|
||||
shift_steps: Number of steps to shift the trajectory.
|
||||
seed_traj: Initial trajectory to seed the optimization. If None, the solver
|
||||
uses the solution from the previous step.
|
||||
max_attempts: Maximum number of attempts to solve the problem.
|
||||
|
||||
Returns:
|
||||
WrapResult: Result of the optimization.
|
||||
"""
|
||||
converged = True
|
||||
|
||||
for _ in range(max_attempts):
|
||||
result = self._step_once(current_state.clone(), shift_steps, seed_traj)
|
||||
if (
|
||||
torch.count_nonzero(torch.isnan(result.action.position)) == 0
|
||||
and torch.count_nonzero(~result.metrics.feasible) == 0
|
||||
):
|
||||
converged = True
|
||||
break
|
||||
self.reset()
|
||||
if not converged:
|
||||
result.action.copy_(current_state)
|
||||
log_warn("MPC didn't converge")
|
||||
|
||||
return result
|
||||
|
||||
def update_goal(self, goal: Goal):
|
||||
"""Update the goal for MPC.
|
||||
|
||||
Args:
|
||||
goal: goal object containing target pose or joint configuration. This goal instance
|
||||
should be created using one of the setup_solve functions.
|
||||
"""
|
||||
self.solver.update_params(goal)
|
||||
|
||||
def reset(self):
|
||||
"""Reset the solver."""
|
||||
# reset warm start
|
||||
self.solver.reset()
|
||||
|
||||
def reset_cuda_graph(self):
|
||||
"""Reset captured CUDA graph. This does not work."""
|
||||
self.solver.reset_cuda_graph()
|
||||
|
||||
def enable_cspace_cost(self, enable=True):
|
||||
"""Enable or disable reaching joint configuration cost in the solver.
|
||||
|
||||
Args:
|
||||
enable: Enable or disable reaching joint configuration cost. When False, cspace cost
|
||||
is disabled.
|
||||
"""
|
||||
self.solver.safety_rollout.enable_cspace_cost(enable)
|
||||
for opt in self.solver.optimizers:
|
||||
opt.rollout_fn.enable_cspace_cost(enable)
|
||||
|
||||
def enable_pose_cost(self, enable=True):
|
||||
"""Enable or disable reaching pose cost in the solver.
|
||||
|
||||
Args:
|
||||
enable: Enable or disable reaching pose cost. When False, pose cost is disabled.
|
||||
"""
|
||||
self.solver.safety_rollout.enable_pose_cost(enable)
|
||||
for opt in self.solver.optimizers:
|
||||
opt.rollout_fn.enable_pose_cost(enable)
|
||||
|
||||
def get_active_js(
|
||||
self,
|
||||
in_js: JointState,
|
||||
):
|
||||
"""Get controlled joints indexed in MPC order from the input joint state.
|
||||
|
||||
Args:
|
||||
in_js: Input joint state.
|
||||
|
||||
Returns:
|
||||
JointState: Joint state with controlled joints.
|
||||
"""
|
||||
|
||||
opt_jnames = self.rollout_fn.joint_names
|
||||
opt_js = in_js.get_ordered_joint_state(opt_jnames)
|
||||
return opt_js
|
||||
|
||||
def update_world(self, world: WorldConfig):
|
||||
"""Update the collision world for the solver.
|
||||
|
||||
This allows for updating the world representation as long as the new world representation
|
||||
does not have a larger number of obstacles than the :attr:`MpcSolver.collision_cache` as
|
||||
created during initialization of :class:`MpcSolverConfig`.
|
||||
|
||||
Args:
|
||||
world: New collision world configuration. See :ref:`world_collision` for more details.
|
||||
"""
|
||||
self.world_coll_checker.load_collision_model(world)
|
||||
|
||||
def get_visual_rollouts(self):
|
||||
"""Get rollouts for debugging."""
|
||||
return self.solver.optimizers[0].get_rollouts()
|
||||
|
||||
@property
|
||||
def joint_names(self):
|
||||
"""Get the ordered joint names of the robot."""
|
||||
return self.rollout_fn.joint_names
|
||||
|
||||
@property
|
||||
def collision_cache(self) -> Dict[str, int]:
|
||||
"""Returns the collision cache created by the world collision checker."""
|
||||
return self.world_coll_checker.cache
|
||||
|
||||
@property
|
||||
def kinematics(self) -> CudaRobotModel:
|
||||
"""Get kinematics instance of the robot."""
|
||||
return self.solver.safety_rollout.dynamics_model.robot_model
|
||||
|
||||
@property
|
||||
def world_collision(self) -> WorldCollision:
|
||||
"""Get the world collision checker."""
|
||||
return self.world_coll_checker
|
||||
|
||||
@property
|
||||
def rollout_fn(self) -> ArmReacher:
|
||||
"""Get the rollout function."""
|
||||
return self.solver.safety_rollout
|
||||
|
||||
def _step_once(
|
||||
self,
|
||||
current_state: JointState,
|
||||
shift_steps: int = 1,
|
||||
seed_traj: Optional[JointState] = None,
|
||||
) -> WrapResult:
|
||||
"""Solve for the next action given the current state.
|
||||
|
||||
Args:
|
||||
current_state: Current joint state of the robot.
|
||||
shift_steps: Number of steps to shift the trajectory.
|
||||
seed_traj: Initial trajectory to seed the optimization. If None, the solver
|
||||
uses the solution from the previous step.
|
||||
|
||||
Returns:
|
||||
WrapResult: Result of the optimization.
|
||||
"""
|
||||
# Create cuda graph for whole solve step including computation of metrics
|
||||
# Including updation of goal buffers
|
||||
|
||||
if self._solve_state is None:
|
||||
log_error("Need to first setup solve state before calling solve()")
|
||||
|
||||
if self.use_cuda_graph_full_step:
|
||||
st_time = time.time()
|
||||
if not self._cu_step_init:
|
||||
self._initialize_cuda_graph_step(current_state, shift_steps, seed_traj)
|
||||
self._cu_state_in.copy_(current_state)
|
||||
if seed_traj is not None:
|
||||
self._cu_seed.copy_(seed_traj)
|
||||
self._cu_step_graph.replay()
|
||||
result = self._cu_result.clone()
|
||||
torch.cuda.synchronize(device=self.tensor_args.device)
|
||||
result.solve_time = time.time() - st_time
|
||||
else:
|
||||
self._step_goal_buffer.current_state.copy_(current_state)
|
||||
result = self._solve_from_solve_state(
|
||||
self._solve_state,
|
||||
self._step_goal_buffer,
|
||||
shift_steps,
|
||||
seed_traj,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _update_solve_state_and_goal_buffer(
|
||||
self,
|
||||
solve_state: ReacherSolveState,
|
||||
goal: Goal,
|
||||
) -> Goal:
|
||||
"""Update solve state and goal for MPC.
|
||||
|
||||
Args:
|
||||
solve_state: New solve state.
|
||||
goal: New goal buffer.
|
||||
|
||||
Returns:
|
||||
Goal: Updated goal buffer.
|
||||
"""
|
||||
self._solve_state, self._goal_buffer, update_reference = solve_state.update_goal(
|
||||
goal,
|
||||
self._solve_state,
|
||||
@@ -250,71 +671,64 @@ class MpcSolver(MpcSolverConfig):
|
||||
)
|
||||
return self._goal_buffer
|
||||
|
||||
def step(
|
||||
def _update_batch_size(self, batch_size: int):
|
||||
"""Update the batch size of the solver.
|
||||
|
||||
Args:
|
||||
batch_size: Number of problems to solve in parallel.
|
||||
"""
|
||||
if self.batch_size != batch_size:
|
||||
self.batch_size = batch_size
|
||||
|
||||
def _solve_from_solve_state(
|
||||
self,
|
||||
current_state: JointState,
|
||||
shift_steps: int = 1,
|
||||
seed_traj: Optional[JointState] = None,
|
||||
max_attempts: int = 1,
|
||||
):
|
||||
converged = True
|
||||
|
||||
for _ in range(max_attempts):
|
||||
result = self.step_once(current_state.clone(), shift_steps, seed_traj)
|
||||
if (
|
||||
torch.count_nonzero(torch.isnan(result.action.position)) == 0
|
||||
and torch.max(torch.abs(result.action.position)) < 10.0
|
||||
and torch.count_nonzero(~result.metrics.feasible) == 0
|
||||
):
|
||||
converged = True
|
||||
break
|
||||
self.reset()
|
||||
if not converged:
|
||||
result.action.copy_(current_state)
|
||||
log_warn("NOT CONVERGED")
|
||||
|
||||
return result
|
||||
|
||||
def step_once(
|
||||
self,
|
||||
current_state: JointState,
|
||||
solve_state: ReacherSolveState,
|
||||
goal: Goal,
|
||||
shift_steps: int = 1,
|
||||
seed_traj: Optional[JointState] = None,
|
||||
) -> WrapResult:
|
||||
# Create cuda graph for whole solve step including computation of metrics
|
||||
# Including updation of goal buffers
|
||||
"""Solve for the next action given the current state.
|
||||
|
||||
if self._solve_state is None:
|
||||
log_error("Need to first setup solve state before calling solve()")
|
||||
Args:
|
||||
solve_state: solve state object containing information about the current MPC problem.
|
||||
goal: goal object containing target pose or joint configuration.
|
||||
shift_steps: Number of steps to shift the trajectory before optimization.
|
||||
seed_traj: Initial trajectory to seed the optimization. If None, the solver
|
||||
uses the solution from the previous step.
|
||||
|
||||
if self.use_cuda_graph_full_step:
|
||||
st_time = time.time()
|
||||
if not self._cu_step_init:
|
||||
self._initialize_cuda_graph_step(current_state, shift_steps, seed_traj)
|
||||
self._cu_state_in.copy_(current_state)
|
||||
if seed_traj is not None:
|
||||
self._cu_seed.copy_(seed_traj)
|
||||
self._cu_step_graph.replay()
|
||||
result = self._cu_result.clone()
|
||||
torch.cuda.synchronize()
|
||||
result.solve_time = time.time() - st_time
|
||||
else:
|
||||
self._step_goal_buffer.current_state.copy_(current_state)
|
||||
result = self._solve_from_solve_state(
|
||||
self._solve_state,
|
||||
self._step_goal_buffer,
|
||||
shift_steps,
|
||||
seed_traj,
|
||||
)
|
||||
Returns:
|
||||
WrapResult: Result of the optimization.
|
||||
"""
|
||||
if solve_state.batch_env:
|
||||
if solve_state.batch_size > self.world_coll_checker.n_envs:
|
||||
log_error("Batch Env is less that goal batch")
|
||||
|
||||
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
|
||||
|
||||
if seed_traj is not None:
|
||||
self.solver.update_init_seed(seed_traj)
|
||||
|
||||
result = self.solver.solve(goal_buffer, seed_traj, shift_steps)
|
||||
result.js_action = self.rollout_fn.get_full_dof_from_solution(result.action)
|
||||
return result
|
||||
|
||||
def _step(
|
||||
def _mpc_step(
|
||||
self,
|
||||
current_state: JointState,
|
||||
shift_steps: int = 1,
|
||||
seed_traj: Optional[JointState] = None,
|
||||
):
|
||||
"""One step function that is used to create a CUDA graph.
|
||||
|
||||
Args:
|
||||
current_state: Current joint state of the robot.
|
||||
shift_steps: Number of steps to shift the trajectory.
|
||||
seed_traj: Initial trajectory to seed the optimization. If None, the solver
|
||||
uses the solution from the previous step.
|
||||
|
||||
Returns:
|
||||
WrapResult: Result of the optimization.
|
||||
"""
|
||||
self._step_goal_buffer.current_state.copy_(current_state)
|
||||
result = self._solve_from_solve_state(
|
||||
self._solve_state,
|
||||
@@ -331,6 +745,14 @@ class MpcSolver(MpcSolverConfig):
|
||||
shift_steps: int = 1,
|
||||
seed_traj: Optional[JointState] = None,
|
||||
):
|
||||
"""Create a CUDA graph for the full step of MPC.
|
||||
|
||||
Args:
|
||||
current_state: Current joint state of the robot.
|
||||
shift_steps: Number of steps to shift the trajectory.
|
||||
seed_traj: Initial trajectory to seed the optimization. If None, the solver
|
||||
uses the solution from the previous step.
|
||||
"""
|
||||
log_info("MpcSolver: Creating Cuda Graph")
|
||||
self._cu_state_in = current_state.clone()
|
||||
if seed_traj is not None:
|
||||
@@ -339,7 +761,7 @@ class MpcSolver(MpcSolverConfig):
|
||||
s.wait_stream(torch.cuda.current_stream(device=self.tensor_args.device))
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(3):
|
||||
self._cu_result = self._step(
|
||||
self._cu_result = self._mpc_step(
|
||||
self._cu_state_in, shift_steps=shift_steps, seed_traj=self._cu_seed
|
||||
)
|
||||
torch.cuda.current_stream(device=self.tensor_args.device).wait_stream(s)
|
||||
@@ -350,142 +772,3 @@ class MpcSolver(MpcSolverConfig):
|
||||
self._cu_state_in, shift_steps=shift_steps, seed_traj=self._cu_seed
|
||||
)
|
||||
self._cu_step_init = True
|
||||
|
||||
def setup_solve_single(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.SINGLE, num_mpc_seeds=num_seeds, batch_size=1, n_envs=1, n_goalset=1
|
||||
)
|
||||
goal_buffer = self.update_goal_buffer(solve_state, goal)
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.GOALSET,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=1,
|
||||
n_envs=1,
|
||||
n_goalset=goal.n_goalset,
|
||||
)
|
||||
goal_buffer = self.update_goal_buffer(solve_state, goal)
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_batch(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.BATCH,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=goal.batch,
|
||||
n_envs=1,
|
||||
n_goalset=1,
|
||||
)
|
||||
goal_buffer = self.update_goal_buffer(solve_state, goal)
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_batch_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.BATCH_GOALSET,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=goal.batch,
|
||||
n_envs=1,
|
||||
n_goalset=goal.n_goalset,
|
||||
)
|
||||
goal_buffer = self.update_goal_buffer(solve_state, goal)
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_batch_env(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.BATCH_ENV,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=goal.batch,
|
||||
n_envs=1,
|
||||
n_goalset=1,
|
||||
)
|
||||
goal_buffer = self.update_goal_buffer(solve_state, goal)
|
||||
return goal_buffer
|
||||
|
||||
def setup_solve_batch_env_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
|
||||
solve_state = ReacherSolveState(
|
||||
ReacherSolveType.BATCH_ENV_GOALSET,
|
||||
num_mpc_seeds=num_seeds,
|
||||
batch_size=goal.batch,
|
||||
n_envs=1,
|
||||
n_goalset=goal.n_goalset,
|
||||
)
|
||||
goal_buffer = self.update_goal_buffer(solve_state, goal)
|
||||
return goal_buffer
|
||||
|
||||
def _solve_from_solve_state(
|
||||
self,
|
||||
solve_state: ReacherSolveState,
|
||||
goal: Goal,
|
||||
shift_steps: int = 1,
|
||||
seed_traj: Optional[JointState] = None,
|
||||
) -> WrapResult:
|
||||
if solve_state.batch_env:
|
||||
if solve_state.batch_size > self.world_coll_checker.n_envs:
|
||||
raise ValueError("Batch Env is less that goal batch")
|
||||
|
||||
goal_buffer = self.update_goal_buffer(solve_state, goal)
|
||||
# NOTE: implement initialization from seed set here:
|
||||
if seed_traj is not None:
|
||||
self.solver.update_init_seed(seed_traj)
|
||||
|
||||
result = self.solver.solve(goal_buffer, seed_traj, shift_steps)
|
||||
result.js_action = self.rollout_fn.get_full_dof_from_solution(result.action)
|
||||
return result
|
||||
|
||||
def fn(self):
|
||||
# this will run one step of optimization and get new command
|
||||
pass
|
||||
|
||||
def update_goal(self, goal: Goal):
|
||||
self.solver.update_params(goal)
|
||||
return True
|
||||
|
||||
def reset(self):
|
||||
# reset warm start
|
||||
self.solver.reset()
|
||||
pass
|
||||
|
||||
def reset_cuda_graph(self):
|
||||
self.solver.reset_cuda_graph()
|
||||
|
||||
@property
|
||||
def rollout_fn(self):
|
||||
return self.solver.safety_rollout
|
||||
|
||||
def enable_cspace_cost(self, enable=True):
|
||||
self.solver.safety_rollout.enable_cspace_cost(enable)
|
||||
for opt in self.solver.optimizers:
|
||||
opt.rollout_fn.enable_cspace_cost(enable)
|
||||
|
||||
def enable_pose_cost(self, enable=True):
|
||||
self.solver.safety_rollout.enable_pose_cost(enable)
|
||||
for opt in self.solver.optimizers:
|
||||
opt.rollout_fn.enable_pose_cost(enable)
|
||||
|
||||
def get_active_js(
|
||||
self,
|
||||
in_js: JointState,
|
||||
):
|
||||
opt_jnames = self.rollout_fn.joint_names
|
||||
opt_js = in_js.get_ordered_joint_state(opt_jnames)
|
||||
return opt_js
|
||||
|
||||
@property
|
||||
def joint_names(self):
|
||||
return self.rollout_fn.joint_names
|
||||
|
||||
def update_world(self, world: WorldConfig):
|
||||
self.world_coll_checker.load_collision_model(world)
|
||||
return True
|
||||
|
||||
def get_visual_rollouts(self):
|
||||
return self.solver.optimizers[0].get_rollouts()
|
||||
|
||||
@property
|
||||
def kinematics(self):
|
||||
return self.solver.safety_rollout.dynamics_model.robot_model
|
||||
|
||||
@property
|
||||
def world_collision(self):
|
||||
return self.world_coll_checker
|
||||
|
||||
Reference in New Issue
Block a user