194 lines
6.2 KiB
Python
194 lines
6.2 KiB
Python
#
|
|
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
|
# property and proprietary rights in and to this material, related
|
|
# documentation and any modifications thereto. Any use, reproduction,
|
|
# disclosure or distribution of this material and related documentation
|
|
# without an express license agreement from NVIDIA CORPORATION or
|
|
# its affiliates is strictly prohibited.
|
|
#
|
|
# Standard Library
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any, List, Optional
|
|
|
|
# Third Party
|
|
import torch
|
|
import torch.autograd.profiler as profiler
|
|
|
|
# CuRobo
|
|
from curobo.opt.newton.newton_base import NewtonOptBase
|
|
from curobo.opt.opt_base import Optimizer
|
|
from curobo.opt.particle.particle_opt_base import ParticleOptBase
|
|
from curobo.rollout.rollout_base import Goal, RolloutBase, RolloutMetrics
|
|
from curobo.types.robot import State
|
|
from curobo.util.logger import log_info
|
|
|
|
|
|
@dataclass
|
|
class WrapConfig:
|
|
safety_rollout: RolloutBase
|
|
optimizers: List[Optimizer]
|
|
compute_metrics: bool
|
|
use_cuda_graph_metrics: bool
|
|
sync_cuda_time: bool
|
|
|
|
|
|
@dataclass
|
|
class WrapResult:
|
|
action: State
|
|
solve_time: float
|
|
metrics: Optional[RolloutMetrics] = None
|
|
debug: Any = None
|
|
js_action: Optional[State] = None
|
|
raw_action: Optional[torch.Tensor] = None
|
|
|
|
def clone(self):
|
|
return WrapResult(
|
|
self.action.clone(), self.solve_time, self.metrics.clone(), debug=self.debug
|
|
)
|
|
|
|
|
|
class WrapBase(WrapConfig):
|
|
def __init__(self, config: Optional[WrapConfig] = None):
|
|
if config is not None:
|
|
WrapConfig.__init__(self, **vars(config))
|
|
self.n_problems = 1
|
|
self.opt_dt = 0
|
|
self._rollout_list = None
|
|
self._opt_rollouts = None
|
|
self._init_solver = False
|
|
|
|
def get_metrics(self, state: State, use_cuda_graph: bool = False) -> RolloutMetrics:
|
|
if use_cuda_graph:
|
|
log_info("Using cuda graph")
|
|
return self.safety_rollout.get_metrics_cuda_graph(state)
|
|
return self.safety_rollout.get_metrics(state)
|
|
|
|
def optimize(self, act_seq: torch.Tensor, shift_steps: int = 0) -> torch.Tensor:
|
|
for opt in self.optimizers:
|
|
act_seq = opt.optimize(act_seq, shift_steps)
|
|
return act_seq
|
|
|
|
def get_debug_data(self):
|
|
debug_list = []
|
|
for opt in self.optimizers:
|
|
debug_list.append(opt.debug)
|
|
return debug_list
|
|
|
|
def get_debug_cost(self):
|
|
debug_list = []
|
|
for opt in self.optimizers:
|
|
debug_list.append(opt.debug_cost)
|
|
return debug_list
|
|
|
|
def update_nproblems(self, n_problems):
|
|
if n_problems != self.n_problems:
|
|
self.n_problems = n_problems
|
|
for opt in self.optimizers:
|
|
opt.update_nproblems(self.n_problems)
|
|
|
|
def update_params(self, goal: Goal):
|
|
with profiler.record_function("wrap_base/safety/update_params"):
|
|
log_info("Updating safety params")
|
|
self.safety_rollout.update_params(goal)
|
|
log_info("Updating optimizer params")
|
|
for opt in self.optimizers:
|
|
opt.update_params(goal)
|
|
|
|
def get_init_act(self):
|
|
init_act = self.safety_rollout.get_init_action_seq()
|
|
return init_act
|
|
|
|
def reset(self):
|
|
for opt in self.optimizers:
|
|
opt.reset()
|
|
self.safety_rollout.reset()
|
|
|
|
def reset_seed(self):
|
|
self.safety_rollout.reset_seed()
|
|
for opt in self.optimizers:
|
|
opt.reset_seed()
|
|
|
|
def reset_cuda_graph(self):
|
|
self.safety_rollout.reset_cuda_graph()
|
|
for opt in self.optimizers:
|
|
opt.reset_cuda_graph()
|
|
self._init_solver = False
|
|
|
|
def reset_shape(self):
|
|
self.safety_rollout.reset_shape()
|
|
for opt in self.optimizers:
|
|
opt.reset_shape()
|
|
self._init_solver = False
|
|
|
|
@property
|
|
def rollout_fn(self):
|
|
return self.safety_rollout
|
|
|
|
def solve(self, goal: Goal, seed: Optional[torch.Tensor] = None):
|
|
metrics = None
|
|
|
|
filtered_state = self.safety_rollout.filter_robot_state(goal.current_state)
|
|
goal.current_state = filtered_state
|
|
self.update_params(goal)
|
|
if seed is None:
|
|
seed = self.get_init_act()
|
|
log_info("getting random seed")
|
|
else:
|
|
seed = seed.detach().clone()
|
|
start_time = time.time()
|
|
if not self._init_solver:
|
|
log_info("Solver was not initialized, warming up solver")
|
|
for _ in range(2):
|
|
act_seq = self.optimize(seed, shift_steps=0)
|
|
self._init_solver = True
|
|
act_seq = self.optimize(seed, shift_steps=0)
|
|
self.opt_dt = time.time() - start_time
|
|
|
|
act = self.safety_rollout.get_robot_command(
|
|
filtered_state, act_seq, state_idx=goal.batch_current_state_idx
|
|
)
|
|
|
|
if self.compute_metrics:
|
|
with profiler.record_function("wrap_base/compute_metrics"):
|
|
metrics = self.get_metrics(
|
|
act, self.use_cuda_graph_metrics
|
|
) # TODO: use cuda graph for metrics
|
|
|
|
result = WrapResult(
|
|
action=act,
|
|
solve_time=self.opt_dt,
|
|
metrics=metrics,
|
|
debug={"steps": self.get_debug_data(), "cost": self.get_debug_cost()},
|
|
raw_action=act_seq,
|
|
)
|
|
return result
|
|
|
|
@property
|
|
def newton_optimizer(self) -> NewtonOptBase:
|
|
return self.optimizers[-1]
|
|
|
|
@property
|
|
def particle_optimizer(self) -> ParticleOptBase:
|
|
return self.optimizers[0]
|
|
|
|
@property
|
|
def joint_names(self):
|
|
return self.rollout_fn.cspace_config.joint_names
|
|
|
|
def _get_rollout_instances_from_optimizers(self) -> List[RolloutBase]:
|
|
if self._opt_rollouts is None:
|
|
self._opt_rollouts = []
|
|
for i in self.optimizers:
|
|
self._opt_rollouts.extend(i.get_all_rollout_instances())
|
|
return self._opt_rollouts
|
|
|
|
def get_all_rollout_instances(self) -> List[RolloutBase]:
|
|
if self._rollout_list is None:
|
|
self._rollout_list = [
|
|
self.safety_rollout
|
|
] + self._get_rollout_instances_from_optimizers()
|
|
return self._rollout_list
|