Add re-timing, minimum dt robustness

This commit is contained in:
Balakumar Sundaralingam
2024-04-25 12:24:17 -07:00
parent d6e600c88c
commit 7362ccd4c2
54 changed files with 4773 additions and 2189 deletions

View File

@@ -8,7 +8,31 @@
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
"""
This module contains :meth:`MpcSolver` that provides a high-level interface to for model
predictive control (MPC) for reaching Cartesian poses and also joint configurations while
avoiding obstacles. The solver uses Model Predictive Path Integral (MPPI) optimization as the
solver. MPC only optimizes locally so the robot can get stuck near joint limits or behind
obstacles. To generate global trajectories, use
:py:meth:`~curobo.wrap.reacher.motion_gen.MotionGen`.
A python example is available at :ref:`python_mpc_example`.
.. note::
Gradient-based MPC is also implemented with L-BFGS but is highly experimental and not
recommended for real robots.
.. raw:: html
<p>
<video autoplay="True" loop="True" muted="True" preload="auto" width="100%"><source src="../videos/mpc_clip.mp4" type="video/mp4"></video>
</p>
"""
# Standard Library
import time
@@ -19,6 +43,7 @@ from typing import Dict, Optional, Union
import torch
# CuRobo
from curobo.cuda_robot_model.cuda_robot_model import CudaRobotModel
from curobo.geom.sdf.utils import create_collision_checker
from curobo.geom.sdf.world import CollisionCheckerType, WorldCollision, WorldCollisionConfig
from curobo.geom.types import WorldConfig
@@ -44,22 +69,31 @@ from curobo.wrap.wrap_mpc import WrapConfig, WrapMpc
@dataclass
class MpcSolverConfig:
"""Configuration dataclass for MPC."""
#: MPC Solver.
solver: WrapMpc
#: World Collision Checker.
world_coll_checker: Optional[WorldCollision] = None
#: Numeric precision and device to run computations.
tensor_args: TensorDeviceType = TensorDeviceType()
#: Capture full step in MPC as a single CUDA graph. This is not supported currently.
use_cuda_graph_full_step: bool = False
@staticmethod
def load_from_robot_config(
robot_cfg: Union[Union[str, dict], RobotConfig],
world_cfg: Union[Union[str, dict], WorldConfig],
world_model: Union[Union[str, dict], WorldConfig],
base_cfg: Optional[dict] = None,
tensor_args: TensorDeviceType = TensorDeviceType(),
compute_metrics: bool = True,
use_cuda_graph: Optional[bool] = None,
particle_opt_iters: Optional[int] = None,
self_collision_check: bool = True,
collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.PRIMITIVE,
collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.MESH,
use_es: Optional[bool] = None,
es_learning_rate: Optional[float] = 0.01,
use_cuda_graph_metrics: bool = False,
@@ -74,9 +108,56 @@ class MpcSolverConfig:
use_lbfgs: bool = False,
use_mppi: bool = True,
):
"""Create an MPC solver configuration from robot and world configuration.
Args:
robot_cfg: Robot configuration. Can be a path to a YAML file or a dictionary or
an instance of :class:`~curobo.types.robot.RobotConfig`.
world_model: World configuration. Can be a path to a YAML file or a dictionary or
an instance of :class:`~curobo.geom.types.WorldConfig`.
base_cfg: Base configuration for the solver. This file is used to check constraints
and convergence. If None, the default configuration from ``base_cfg.yml`` is used.
tensor_args: Numeric precision and device to run computations.
compute_metrics: Compute metrics on MPC step.
use_cuda_graph: Use CUDA graph for the optimization step.
particle_opt_iters: Number of iterations for the particle optimization.
self_collision_check: Enable self-collision check during MPC optimization.
collision_checker_type: Type of collision checker to use. See :ref:`world_collision`.
use_es: Use Evolution Strategies (ES) solver for MPC. Highly experimental.
es_learning_rate: Learning rate for ES solver.
use_cuda_graph_metrics: Use CUDA graph for computing metrics.
store_rollouts: Store rollouts information for debugging. This will also store the
trajectory taken by the end-effector across the horizon.
use_cuda_graph_full_step: Capture full step in MPC as a single CUDA graph. This is
experimental and might not work reliably.
sync_cuda_time: Synchronize CUDA device with host using
:py:func:`torch.cuda.synchronize` before calculating compute time.
collision_cache: Cache of obstacles to create to load obstacles between planning calls.
An example: ``{"obb": 10, "mesh": 10}``, to create a cache of 10 cuboids and 10
meshes.
n_collision_envs: Number of collision environments to create for batched planning
across different environments. Only used for :py:meth:`MpcSolver.setup_solve_batch_env`
and :py:meth:`MpcSolver.setup_solve_batch_env_goalset`.
collision_activation_distance: Distance in meters to activate collision cost. A good
value to start with is 0.01 meters. Increase the distance if the robot needs to
stay further away from obstacles.
world_coll_checker: Instance of world collision checker to use for MPC. Leaving this to
None will create a new instance of world collision checker using the provided
:attr:`world_model`.
step_dt: Time step to use between each step in the trajectory. If None, the default
time step from the configuration~(`particle_mpc.yml` or `gradient_mpc.yml`)
is used. This dt should match the control frequency at which you are sending
commands to the robot. This dt should also be greater than than the compute
time for a single step.
use_lbfgs: Use L-BFGS solver for MPC. Highly experimental.
use_mppi: Use MPPI solver for MPC.
Returns:
MpcSolverConfig: Configuration for the MPC solver.
"""
if use_cuda_graph_full_step:
log_error("use_cuda_graph_full_step currently is not supported")
raise ValueError("use_cuda_graph_full_step currently is not supported")
task_file = "particle_mpc.yml"
config_data = load_yaml(join_path(get_task_configs_path(), task_file))
@@ -108,14 +189,14 @@ class MpcSolverConfig:
if isinstance(robot_cfg, dict):
robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args)
if isinstance(world_cfg, str):
world_cfg = load_yaml(join_path(get_world_configs_path(), world_cfg))
if isinstance(world_model, str):
world_model = load_yaml(join_path(get_world_configs_path(), world_model))
if world_coll_checker is None and world_cfg is not None:
world_cfg = WorldCollisionConfig.load_from_dict(
base_cfg["world_collision_checker_cfg"], world_cfg, tensor_args
if world_coll_checker is None and world_model is not None:
world_model = WorldCollisionConfig.load_from_dict(
base_cfg["world_collision_checker_cfg"], world_model, tensor_args
)
world_coll_checker = create_collision_checker(world_cfg)
world_coll_checker = create_collision_checker(world_model)
grad_config_data = None
if use_lbfgs:
grad_config_data = load_yaml(join_path(get_task_configs_path(), "gradient_mpc.yml"))
@@ -134,7 +215,7 @@ class MpcSolverConfig:
base_cfg["constraint"],
base_cfg["convergence"],
base_cfg["world_collision_checker_cfg"],
world_cfg,
world_model,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
@@ -172,7 +253,7 @@ class MpcSolverConfig:
base_cfg["constraint"],
base_cfg["convergence"],
base_cfg["world_collision_checker_cfg"],
world_cfg,
world_model,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
@@ -201,13 +282,42 @@ class MpcSolverConfig:
class MpcSolver(MpcSolverConfig):
"""Model Predictive Control Solver for Arm Reacher task.
"""High-level interface for Model Predictive Control (MPC).
Args:
MpcSolverConfig: _description_
MPC can reach Cartesian poses and joint configurations while avoiding obstacles. The solver
uses Model Predictive Path Integral (MPPI) optimization as the solver. MPC only optimizes
locally so the robot can get stuck near joint limits or behind obstacles. To generate global
trajectories, use :py:meth:`~curobo.wrap.reacher.motion_gen.MotionGen`.
See :ref:`python_mpc_example` for an example. This MPC solver implementation can be used in the
following steps:
1. Create a :py:class:`~curobo.rollout.rollout_base.Goal` object with the target pose or joint
configuration.
2. Create a goal buffer for the problem type using :meth:`setup_solve_single`,
:meth:`setup_solve_goalset`, :meth:`setup_solve_batch`, :meth:`setup_solve_batch_goalset`,
:meth:`setup_solve_batch_env`, or :meth:`setup_solve_batch_env_goalset`. Pass the goal
object from the previous step to this function. This function will update the internal
solve state of MPC and also the goal for MPC. An augmented goal buffer is returned.
3. Call :meth:`step` with the current joint state to get the next action.
4. To change the goal, create a :py:class:`~curobo.types.math.Pose` object with new pose or
:py:class:`~curobo.types.state.JointState` object with new joint configuration. Then
copy the target into the augmented goal buffer using
``goal_buffer.goal_pose.copy_(new_pose)`` or ``goal_buffer.goal_state.copy_(new_state)``.
5. Call :meth:`update_goal` with the augmented goal buffer to update the goal for MPC.
6. Call :meth:`step` with the current joint state to get the next action.
To dynamically change the type of goal reached between pose and joint configuration targets,
create the goal object in step 1 with both targets and then use :meth:`enable_cspace_cost` and
:meth:`enable_pose_cost` to enable or disable reaching joint configuration cost and pose cost.
"""
def __init__(self, config: MpcSolverConfig) -> None:
"""Initializes the MPC solver.
Args:
config: Configuration parameters for MPC.
"""
super().__init__(**vars(config))
self.tensor_args = self.solver.rollout_fn.tensor_args
self._goal_buffer = Goal()
@@ -222,15 +332,326 @@ class MpcSolver(MpcSolverConfig):
self._cu_step_graph = None
self._cu_result = None
def _update_batch_size(self, batch_size):
if self.batch_size != batch_size:
self.batch_size = batch_size
def setup_solve_single(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
"""Creates a goal buffer to solve for a robot to reach target pose or joint configuration.
def update_goal_buffer(
Args:
goal: goal object containing target pose or joint configuration.
num_seeds: Number of seeds to use in the solver.
Returns:
Goal: Instance of augmented goal buffer.
"""
solve_state = ReacherSolveState(
ReacherSolveType.SINGLE, num_mpc_seeds=num_seeds, batch_size=1, n_envs=1, n_goalset=1
)
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
self.update_goal(goal_buffer)
return goal_buffer
def setup_solve_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
"""Creates a goal buffer to solve for a robot to reach a pose in a set of poses.
Args:
goal: goal object containing target goalset or joint configuration.
num_seeds: Number of seeds to use in the solver.
Returns:
Goal: Instance of augmented goal buffer.
"""
solve_state = ReacherSolveState(
ReacherSolveType.GOALSET,
num_mpc_seeds=num_seeds,
batch_size=1,
n_envs=1,
n_goalset=goal.n_goalset,
)
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
self.update_goal(goal_buffer)
return goal_buffer
def setup_solve_batch(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
"""Creates a goal buffer to solve for a batch of robots to reach targets.
Args:
goal: goal object containing target poses or joint configurations.
num_seeds: Number of seeds to use in the solver.
Returns:
Goal: Instance of augmented goal buffer.
"""
solve_state = ReacherSolveState(
ReacherSolveType.BATCH,
num_mpc_seeds=num_seeds,
batch_size=goal.batch,
n_envs=1,
n_goalset=1,
)
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
self.update_goal(goal_buffer)
return goal_buffer
def setup_solve_batch_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
"""Creates a goal buffer to solve for a batch of robots to reach a set of poses.
Args:
goal: goal object containing target goalset or joint configurations.
num_seeds: Number of seeds to use in the solver.
Returns:
Goal: Instance of augmented goal buffer.
"""
solve_state = ReacherSolveState(
ReacherSolveType.BATCH_GOALSET,
num_mpc_seeds=num_seeds,
batch_size=goal.batch,
n_envs=1,
n_goalset=goal.n_goalset,
)
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
self.update_goal(goal_buffer)
return goal_buffer
def setup_solve_batch_env(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
"""Creates a goal buffer to solve for a batch of robots in different collision worlds.
Args:
goal: goal object containing target poses or joint configurations.
num_seeds: Number of seeds to use in the solver.
Returns:
Goal: Instance of augmented goal buffer.
"""
solve_state = ReacherSolveState(
ReacherSolveType.BATCH_ENV,
num_mpc_seeds=num_seeds,
batch_size=goal.batch,
n_envs=1,
n_goalset=1,
)
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
self.update_goal(goal_buffer)
return goal_buffer
def setup_solve_batch_env_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
"""Creates a goal buffer to solve for a batch of robots in different collision worlds.
Args:
goal: goal object containing target goalset or joint configurations.
num_seeds: Number of seeds to use in the solver.
Returns:
Goal: Instance of augmented goal buffer.
"""
solve_state = ReacherSolveState(
ReacherSolveType.BATCH_ENV_GOALSET,
num_mpc_seeds=num_seeds,
batch_size=goal.batch,
n_envs=1,
n_goalset=goal.n_goalset,
)
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
self.update_goal(goal_buffer)
return goal_buffer
def step(
self,
current_state: JointState,
shift_steps: int = 1,
seed_traj: Optional[JointState] = None,
max_attempts: int = 1,
):
"""Solve for the next action given the current state.
Args:
current_state: Current joint state of the robot.
shift_steps: Number of steps to shift the trajectory.
seed_traj: Initial trajectory to seed the optimization. If None, the solver
uses the solution from the previous step.
max_attempts: Maximum number of attempts to solve the problem.
Returns:
WrapResult: Result of the optimization.
"""
converged = True
for _ in range(max_attempts):
result = self._step_once(current_state.clone(), shift_steps, seed_traj)
if (
torch.count_nonzero(torch.isnan(result.action.position)) == 0
and torch.count_nonzero(~result.metrics.feasible) == 0
):
converged = True
break
self.reset()
if not converged:
result.action.copy_(current_state)
log_warn("MPC didn't converge")
return result
def update_goal(self, goal: Goal):
"""Update the goal for MPC.
Args:
goal: goal object containing target pose or joint configuration. This goal instance
should be created using one of the setup_solve functions.
"""
self.solver.update_params(goal)
def reset(self):
"""Reset the solver."""
# reset warm start
self.solver.reset()
def reset_cuda_graph(self):
"""Reset captured CUDA graph. This does not work."""
self.solver.reset_cuda_graph()
def enable_cspace_cost(self, enable=True):
"""Enable or disable reaching joint configuration cost in the solver.
Args:
enable: Enable or disable reaching joint configuration cost. When False, cspace cost
is disabled.
"""
self.solver.safety_rollout.enable_cspace_cost(enable)
for opt in self.solver.optimizers:
opt.rollout_fn.enable_cspace_cost(enable)
def enable_pose_cost(self, enable=True):
"""Enable or disable reaching pose cost in the solver.
Args:
enable: Enable or disable reaching pose cost. When False, pose cost is disabled.
"""
self.solver.safety_rollout.enable_pose_cost(enable)
for opt in self.solver.optimizers:
opt.rollout_fn.enable_pose_cost(enable)
def get_active_js(
self,
in_js: JointState,
):
"""Get controlled joints indexed in MPC order from the input joint state.
Args:
in_js: Input joint state.
Returns:
JointState: Joint state with controlled joints.
"""
opt_jnames = self.rollout_fn.joint_names
opt_js = in_js.get_ordered_joint_state(opt_jnames)
return opt_js
def update_world(self, world: WorldConfig):
"""Update the collision world for the solver.
This allows for updating the world representation as long as the new world representation
does not have a larger number of obstacles than the :attr:`MpcSolver.collision_cache` as
created during initialization of :class:`MpcSolverConfig`.
Args:
world: New collision world configuration. See :ref:`world_collision` for more details.
"""
self.world_coll_checker.load_collision_model(world)
def get_visual_rollouts(self):
"""Get rollouts for debugging."""
return self.solver.optimizers[0].get_rollouts()
@property
def joint_names(self):
"""Get the ordered joint names of the robot."""
return self.rollout_fn.joint_names
@property
def collision_cache(self) -> Dict[str, int]:
"""Returns the collision cache created by the world collision checker."""
return self.world_coll_checker.cache
@property
def kinematics(self) -> CudaRobotModel:
"""Get kinematics instance of the robot."""
return self.solver.safety_rollout.dynamics_model.robot_model
@property
def world_collision(self) -> WorldCollision:
"""Get the world collision checker."""
return self.world_coll_checker
@property
def rollout_fn(self) -> ArmReacher:
"""Get the rollout function."""
return self.solver.safety_rollout
def _step_once(
self,
current_state: JointState,
shift_steps: int = 1,
seed_traj: Optional[JointState] = None,
) -> WrapResult:
"""Solve for the next action given the current state.
Args:
current_state: Current joint state of the robot.
shift_steps: Number of steps to shift the trajectory.
seed_traj: Initial trajectory to seed the optimization. If None, the solver
uses the solution from the previous step.
Returns:
WrapResult: Result of the optimization.
"""
# Create cuda graph for whole solve step including computation of metrics
# Including updation of goal buffers
if self._solve_state is None:
log_error("Need to first setup solve state before calling solve()")
if self.use_cuda_graph_full_step:
st_time = time.time()
if not self._cu_step_init:
self._initialize_cuda_graph_step(current_state, shift_steps, seed_traj)
self._cu_state_in.copy_(current_state)
if seed_traj is not None:
self._cu_seed.copy_(seed_traj)
self._cu_step_graph.replay()
result = self._cu_result.clone()
torch.cuda.synchronize(device=self.tensor_args.device)
result.solve_time = time.time() - st_time
else:
self._step_goal_buffer.current_state.copy_(current_state)
result = self._solve_from_solve_state(
self._solve_state,
self._step_goal_buffer,
shift_steps,
seed_traj,
)
return result
def _update_solve_state_and_goal_buffer(
self,
solve_state: ReacherSolveState,
goal: Goal,
) -> Goal:
"""Update solve state and goal for MPC.
Args:
solve_state: New solve state.
goal: New goal buffer.
Returns:
Goal: Updated goal buffer.
"""
self._solve_state, self._goal_buffer, update_reference = solve_state.update_goal(
goal,
self._solve_state,
@@ -250,71 +671,64 @@ class MpcSolver(MpcSolverConfig):
)
return self._goal_buffer
def step(
def _update_batch_size(self, batch_size: int):
"""Update the batch size of the solver.
Args:
batch_size: Number of problems to solve in parallel.
"""
if self.batch_size != batch_size:
self.batch_size = batch_size
def _solve_from_solve_state(
self,
current_state: JointState,
shift_steps: int = 1,
seed_traj: Optional[JointState] = None,
max_attempts: int = 1,
):
converged = True
for _ in range(max_attempts):
result = self.step_once(current_state.clone(), shift_steps, seed_traj)
if (
torch.count_nonzero(torch.isnan(result.action.position)) == 0
and torch.max(torch.abs(result.action.position)) < 10.0
and torch.count_nonzero(~result.metrics.feasible) == 0
):
converged = True
break
self.reset()
if not converged:
result.action.copy_(current_state)
log_warn("NOT CONVERGED")
return result
def step_once(
self,
current_state: JointState,
solve_state: ReacherSolveState,
goal: Goal,
shift_steps: int = 1,
seed_traj: Optional[JointState] = None,
) -> WrapResult:
# Create cuda graph for whole solve step including computation of metrics
# Including updation of goal buffers
"""Solve for the next action given the current state.
if self._solve_state is None:
log_error("Need to first setup solve state before calling solve()")
Args:
solve_state: solve state object containing information about the current MPC problem.
goal: goal object containing target pose or joint configuration.
shift_steps: Number of steps to shift the trajectory before optimization.
seed_traj: Initial trajectory to seed the optimization. If None, the solver
uses the solution from the previous step.
if self.use_cuda_graph_full_step:
st_time = time.time()
if not self._cu_step_init:
self._initialize_cuda_graph_step(current_state, shift_steps, seed_traj)
self._cu_state_in.copy_(current_state)
if seed_traj is not None:
self._cu_seed.copy_(seed_traj)
self._cu_step_graph.replay()
result = self._cu_result.clone()
torch.cuda.synchronize()
result.solve_time = time.time() - st_time
else:
self._step_goal_buffer.current_state.copy_(current_state)
result = self._solve_from_solve_state(
self._solve_state,
self._step_goal_buffer,
shift_steps,
seed_traj,
)
Returns:
WrapResult: Result of the optimization.
"""
if solve_state.batch_env:
if solve_state.batch_size > self.world_coll_checker.n_envs:
log_error("Batch Env is less that goal batch")
goal_buffer = self._update_solve_state_and_goal_buffer(solve_state, goal)
if seed_traj is not None:
self.solver.update_init_seed(seed_traj)
result = self.solver.solve(goal_buffer, seed_traj, shift_steps)
result.js_action = self.rollout_fn.get_full_dof_from_solution(result.action)
return result
def _step(
def _mpc_step(
self,
current_state: JointState,
shift_steps: int = 1,
seed_traj: Optional[JointState] = None,
):
"""One step function that is used to create a CUDA graph.
Args:
current_state: Current joint state of the robot.
shift_steps: Number of steps to shift the trajectory.
seed_traj: Initial trajectory to seed the optimization. If None, the solver
uses the solution from the previous step.
Returns:
WrapResult: Result of the optimization.
"""
self._step_goal_buffer.current_state.copy_(current_state)
result = self._solve_from_solve_state(
self._solve_state,
@@ -331,6 +745,14 @@ class MpcSolver(MpcSolverConfig):
shift_steps: int = 1,
seed_traj: Optional[JointState] = None,
):
"""Create a CUDA graph for the full step of MPC.
Args:
current_state: Current joint state of the robot.
shift_steps: Number of steps to shift the trajectory.
seed_traj: Initial trajectory to seed the optimization. If None, the solver
uses the solution from the previous step.
"""
log_info("MpcSolver: Creating Cuda Graph")
self._cu_state_in = current_state.clone()
if seed_traj is not None:
@@ -339,7 +761,7 @@ class MpcSolver(MpcSolverConfig):
s.wait_stream(torch.cuda.current_stream(device=self.tensor_args.device))
with torch.cuda.stream(s):
for _ in range(3):
self._cu_result = self._step(
self._cu_result = self._mpc_step(
self._cu_state_in, shift_steps=shift_steps, seed_traj=self._cu_seed
)
torch.cuda.current_stream(device=self.tensor_args.device).wait_stream(s)
@@ -350,142 +772,3 @@ class MpcSolver(MpcSolverConfig):
self._cu_state_in, shift_steps=shift_steps, seed_traj=self._cu_seed
)
self._cu_step_init = True
def setup_solve_single(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
solve_state = ReacherSolveState(
ReacherSolveType.SINGLE, num_mpc_seeds=num_seeds, batch_size=1, n_envs=1, n_goalset=1
)
goal_buffer = self.update_goal_buffer(solve_state, goal)
return goal_buffer
def setup_solve_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
solve_state = ReacherSolveState(
ReacherSolveType.GOALSET,
num_mpc_seeds=num_seeds,
batch_size=1,
n_envs=1,
n_goalset=goal.n_goalset,
)
goal_buffer = self.update_goal_buffer(solve_state, goal)
return goal_buffer
def setup_solve_batch(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
solve_state = ReacherSolveState(
ReacherSolveType.BATCH,
num_mpc_seeds=num_seeds,
batch_size=goal.batch,
n_envs=1,
n_goalset=1,
)
goal_buffer = self.update_goal_buffer(solve_state, goal)
return goal_buffer
def setup_solve_batch_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
solve_state = ReacherSolveState(
ReacherSolveType.BATCH_GOALSET,
num_mpc_seeds=num_seeds,
batch_size=goal.batch,
n_envs=1,
n_goalset=goal.n_goalset,
)
goal_buffer = self.update_goal_buffer(solve_state, goal)
return goal_buffer
def setup_solve_batch_env(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
solve_state = ReacherSolveState(
ReacherSolveType.BATCH_ENV,
num_mpc_seeds=num_seeds,
batch_size=goal.batch,
n_envs=1,
n_goalset=1,
)
goal_buffer = self.update_goal_buffer(solve_state, goal)
return goal_buffer
def setup_solve_batch_env_goalset(self, goal: Goal, num_seeds: Optional[int] = None) -> Goal:
solve_state = ReacherSolveState(
ReacherSolveType.BATCH_ENV_GOALSET,
num_mpc_seeds=num_seeds,
batch_size=goal.batch,
n_envs=1,
n_goalset=goal.n_goalset,
)
goal_buffer = self.update_goal_buffer(solve_state, goal)
return goal_buffer
def _solve_from_solve_state(
self,
solve_state: ReacherSolveState,
goal: Goal,
shift_steps: int = 1,
seed_traj: Optional[JointState] = None,
) -> WrapResult:
if solve_state.batch_env:
if solve_state.batch_size > self.world_coll_checker.n_envs:
raise ValueError("Batch Env is less that goal batch")
goal_buffer = self.update_goal_buffer(solve_state, goal)
# NOTE: implement initialization from seed set here:
if seed_traj is not None:
self.solver.update_init_seed(seed_traj)
result = self.solver.solve(goal_buffer, seed_traj, shift_steps)
result.js_action = self.rollout_fn.get_full_dof_from_solution(result.action)
return result
def fn(self):
# this will run one step of optimization and get new command
pass
def update_goal(self, goal: Goal):
self.solver.update_params(goal)
return True
def reset(self):
# reset warm start
self.solver.reset()
pass
def reset_cuda_graph(self):
self.solver.reset_cuda_graph()
@property
def rollout_fn(self):
return self.solver.safety_rollout
def enable_cspace_cost(self, enable=True):
self.solver.safety_rollout.enable_cspace_cost(enable)
for opt in self.solver.optimizers:
opt.rollout_fn.enable_cspace_cost(enable)
def enable_pose_cost(self, enable=True):
self.solver.safety_rollout.enable_pose_cost(enable)
for opt in self.solver.optimizers:
opt.rollout_fn.enable_pose_cost(enable)
def get_active_js(
self,
in_js: JointState,
):
opt_jnames = self.rollout_fn.joint_names
opt_js = in_js.get_ordered_joint_state(opt_jnames)
return opt_js
@property
def joint_names(self):
return self.rollout_fn.joint_names
def update_world(self, world: WorldConfig):
self.world_coll_checker.load_collision_model(world)
return True
def get_visual_rollouts(self):
return self.solver.optimizers[0].get_rollouts()
@property
def kinematics(self):
return self.solver.safety_rollout.dynamics_model.robot_model
@property
def world_collision(self):
return self.world_coll_checker