release repository
This commit is contained in:
185
src/curobo/wrap/wrap_base.py
Normal file
185
src/curobo/wrap/wrap_base.py
Normal file
@@ -0,0 +1,185 @@
|
||||
#
|
||||
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
#
|
||||
# Standard Library
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
import torch.autograd.profiler as profiler
|
||||
|
||||
# CuRobo
|
||||
from curobo.opt.newton.newton_base import NewtonOptBase
|
||||
from curobo.opt.opt_base import Optimizer
|
||||
from curobo.opt.particle.particle_opt_base import ParticleOptBase
|
||||
from curobo.rollout.rollout_base import Goal, RolloutBase, RolloutMetrics
|
||||
from curobo.types.robot import State
|
||||
from curobo.util.logger import log_info
|
||||
|
||||
|
||||
@dataclass
|
||||
class WrapConfig:
|
||||
safety_rollout: RolloutBase
|
||||
optimizers: List[Optimizer]
|
||||
compute_metrics: bool
|
||||
use_cuda_graph_metrics: bool
|
||||
sync_cuda_time: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class WrapResult:
|
||||
action: State
|
||||
solve_time: float
|
||||
metrics: Optional[RolloutMetrics] = None
|
||||
debug: Any = None
|
||||
js_action: Optional[State] = None
|
||||
|
||||
def clone(self):
|
||||
return WrapResult(
|
||||
self.action.clone(), self.solve_time, self.metrics.clone(), debug=self.debug
|
||||
)
|
||||
|
||||
|
||||
class WrapBase(WrapConfig):
|
||||
def __init__(self, config: Optional[WrapConfig] = None):
|
||||
if config is not None:
|
||||
WrapConfig.__init__(self, **vars(config))
|
||||
self.n_envs = 1
|
||||
self.opt_dt = 0
|
||||
self._rollout_list = None
|
||||
self._opt_rollouts = None
|
||||
self._init_solver = False
|
||||
|
||||
def get_metrics(self, state: State, use_cuda_graph: bool = False) -> RolloutMetrics:
|
||||
if use_cuda_graph:
|
||||
log_info("Using cuda graph")
|
||||
return self.safety_rollout.get_metrics_cuda_graph(state)
|
||||
return self.safety_rollout.get_metrics(state)
|
||||
|
||||
def optimize(self, act_seq: torch.Tensor, shift_steps: int = 0) -> torch.Tensor:
|
||||
for opt in self.optimizers:
|
||||
act_seq = opt.optimize(act_seq, shift_steps)
|
||||
return act_seq
|
||||
|
||||
def get_debug_data(self):
|
||||
debug_list = []
|
||||
for opt in self.optimizers:
|
||||
debug_list.append(opt.debug)
|
||||
return debug_list
|
||||
|
||||
def get_debug_cost(self):
|
||||
debug_list = []
|
||||
for opt in self.optimizers:
|
||||
debug_list.append(opt.debug_cost)
|
||||
return debug_list
|
||||
|
||||
def update_nenvs(self, n_envs):
|
||||
if n_envs != self.n_envs:
|
||||
self.n_envs = n_envs
|
||||
for opt in self.optimizers:
|
||||
opt.update_nenvs(self.n_envs)
|
||||
|
||||
def update_params(self, goal: Goal):
|
||||
with profiler.record_function("wrap_base/safety/update_params"):
|
||||
log_info("Updating safety params")
|
||||
self.safety_rollout.update_params(goal)
|
||||
log_info("Updating optimizer params")
|
||||
for opt in self.optimizers:
|
||||
opt.update_params(goal)
|
||||
|
||||
def get_init_act(self):
|
||||
init_act = self.safety_rollout.get_init_action_seq()
|
||||
return init_act
|
||||
|
||||
def reset(self):
|
||||
for opt in self.optimizers:
|
||||
opt.reset()
|
||||
self.safety_rollout.reset()
|
||||
|
||||
def reset_seed(self):
|
||||
self.safety_rollout.reset_seed()
|
||||
for opt in self.optimizers:
|
||||
opt.reset_seed()
|
||||
|
||||
def reset_cuda_graph(self):
|
||||
self.safety_rollout.reset_cuda_graph()
|
||||
for opt in self.optimizers:
|
||||
opt.reset_cuda_graph()
|
||||
self._init_solver = False
|
||||
|
||||
@property
|
||||
def rollout_fn(self):
|
||||
return self.safety_rollout
|
||||
|
||||
def solve(self, goal: Goal, seed: Optional[torch.Tensor] = None):
|
||||
metrics = None
|
||||
|
||||
filtered_state = self.safety_rollout.filter_robot_state(goal.current_state)
|
||||
goal.current_state = filtered_state
|
||||
self.update_params(goal)
|
||||
if seed is None:
|
||||
seed = self.get_init_act()
|
||||
log_info("getting random seed")
|
||||
else:
|
||||
seed = seed.detach().clone()
|
||||
start_time = time.time()
|
||||
if not self._init_solver:
|
||||
log_info("Solver was not initialized, warming up solver")
|
||||
for _ in range(2):
|
||||
act_seq = self.optimize(seed, shift_steps=0)
|
||||
self._init_solver = True
|
||||
act_seq = self.optimize(seed, shift_steps=0)
|
||||
self.opt_dt = time.time() - start_time
|
||||
|
||||
act = self.safety_rollout.get_robot_command(
|
||||
filtered_state, act_seq, state_idx=goal.batch_current_state_idx
|
||||
)
|
||||
|
||||
if self.compute_metrics:
|
||||
with profiler.record_function("wrap_base/compute_metrics"):
|
||||
metrics = self.get_metrics(
|
||||
act, self.use_cuda_graph_metrics
|
||||
) # TODO: use cuda graph for metrics
|
||||
|
||||
result = WrapResult(
|
||||
action=act,
|
||||
solve_time=self.opt_dt,
|
||||
metrics=metrics,
|
||||
debug={"steps": self.get_debug_data(), "cost": self.get_debug_cost()},
|
||||
)
|
||||
return result
|
||||
|
||||
@property
|
||||
def newton_optimizer(self) -> NewtonOptBase:
|
||||
return self.optimizers[-1]
|
||||
|
||||
@property
|
||||
def particle_optimizer(self) -> ParticleOptBase:
|
||||
return self.optimizers[0]
|
||||
|
||||
@property
|
||||
def joint_names(self):
|
||||
return self.rollout_fn.cspace_config.joint_names
|
||||
|
||||
def _get_rollout_instances_from_optimizers(self) -> List[RolloutBase]:
|
||||
if self._opt_rollouts is None:
|
||||
self._opt_rollouts = []
|
||||
for i in self.optimizers:
|
||||
self._opt_rollouts.extend(i.get_all_rollout_instances())
|
||||
return self._opt_rollouts
|
||||
|
||||
def get_all_rollout_instances(self) -> List[RolloutBase]:
|
||||
if self._rollout_list is None:
|
||||
self._rollout_list = [
|
||||
self.safety_rollout
|
||||
] + self._get_rollout_instances_from_optimizers()
|
||||
return self._rollout_list
|
||||
Reference in New Issue
Block a user