release repository

Balakumar Sundaralingam
2023-10-26 04:17:19 -07:00
commit 07e6ccfc91
287 changed files with 70659 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
"""
This module contains rollout classes
"""

View File

@@ -0,0 +1,751 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
# Third Party
import torch
import torch.autograd.profiler as profiler
# CuRobo
from curobo.geom.sdf.utils import create_collision_checker
from curobo.geom.sdf.world import WorldCollision, WorldCollisionConfig
from curobo.geom.types import WorldConfig
from curobo.rollout.cost.bound_cost import BoundCost, BoundCostConfig
from curobo.rollout.cost.dist_cost import DistCost, DistCostConfig
from curobo.rollout.cost.manipulability_cost import ManipulabilityCost, ManipulabilityCostConfig
from curobo.rollout.cost.primitive_collision_cost import (
PrimitiveCollisionCost,
PrimitiveCollisionCostConfig,
)
from curobo.rollout.cost.self_collision_cost import SelfCollisionCost, SelfCollisionCostConfig
from curobo.rollout.cost.stop_cost import StopCost, StopCostConfig
from curobo.rollout.dynamics_model.kinematic_model import (
KinematicModel,
KinematicModelConfig,
KinematicModelState,
)
from curobo.rollout.rollout_base import Goal, RolloutBase, RolloutConfig, RolloutMetrics, Trajectory
from curobo.types.base import TensorDeviceType
from curobo.types.robot import CSpaceConfig, RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_info, log_warn
from curobo.util.tensor_util import cat_sum
@dataclass
class ArmCostConfig:
bound_cfg: Optional[BoundCostConfig] = None
null_space_cfg: Optional[DistCostConfig] = None
manipulability_cfg: Optional[ManipulabilityCostConfig] = None
stop_cfg: Optional[StopCostConfig] = None
self_collision_cfg: Optional[SelfCollisionCostConfig] = None
primitive_collision_cfg: Optional[PrimitiveCollisionCostConfig] = None
@staticmethod
def _get_base_keys():
k_list = {
"null_space_cfg": DistCostConfig,
"manipulability_cfg": ManipulabilityCostConfig,
"stop_cfg": StopCostConfig,
"self_collision_cfg": SelfCollisionCostConfig,
"bound_cfg": BoundCostConfig,
}
return k_list
@staticmethod
def from_dict(
data_dict: Dict,
robot_config: RobotConfig,
world_coll_checker: Optional[WorldCollision] = None,
tensor_args: TensorDeviceType = TensorDeviceType(),
):
k_list = ArmCostConfig._get_base_keys()
data = ArmCostConfig._get_formatted_dict(
data_dict,
k_list,
robot_config,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
return ArmCostConfig(**data)
@staticmethod
def _get_formatted_dict(
data_dict: Dict,
cost_key_list: Dict,
robot_config: RobotConfig,
world_coll_checker: Optional[WorldCollision] = None,
tensor_args: TensorDeviceType = TensorDeviceType(),
):
data = {}
for k in cost_key_list:
if k in data_dict:
data[k] = cost_key_list[k](**data_dict[k], tensor_args=tensor_args)
if "primitive_collision_cfg" in data_dict and world_coll_checker is not None:
data["primitive_collision_cfg"] = PrimitiveCollisionCostConfig(
**data_dict["primitive_collision_cfg"],
world_coll_checker=world_coll_checker,
tensor_args=tensor_args
)
return data
@dataclass
class ArmBaseConfig(RolloutConfig):
model_cfg: KinematicModelConfig
cost_cfg: ArmCostConfig
constraint_cfg: ArmCostConfig
convergence_cfg: ArmCostConfig
world_coll_checker: Optional[WorldCollision] = None
@staticmethod
def model_from_dict(
model_data_dict: Dict,
robot_cfg: RobotConfig,
tensor_args: TensorDeviceType = TensorDeviceType(),
):
return KinematicModelConfig.from_dict(model_data_dict, robot_cfg, tensor_args=tensor_args)
@staticmethod
def cost_from_dict(
cost_data_dict: Dict,
robot_cfg: RobotConfig,
world_coll_checker: Optional[WorldCollision] = None,
tensor_args: TensorDeviceType = TensorDeviceType(),
):
return ArmCostConfig.from_dict(
cost_data_dict,
robot_cfg,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
@staticmethod
def world_coll_checker_from_dict(
world_coll_checker_dict: Optional[Dict] = None,
world_model_dict: Optional[Union[WorldConfig, Dict]] = None,
world_coll_checker: Optional[WorldCollision] = None,
tensor_args: TensorDeviceType = TensorDeviceType(),
):
# TODO: Check which type of collision checker and load that.
if (
world_coll_checker is None
and world_model_dict is not None
and world_coll_checker_dict is not None
):
world_coll_cfg = WorldCollisionConfig.load_from_dict(
world_coll_checker_dict, world_model_dict, tensor_args
)
world_coll_checker = create_collision_checker(world_coll_cfg)
else:
log_info("*******USING EXISTING COLLISION CHECKER***********")
return world_coll_checker
@classmethod
@profiler.record_function("arm_base_config/from_dict")
def from_dict(
cls,
robot_cfg: Union[Dict, RobotConfig],
model_data_dict: Dict,
cost_data_dict: Dict,
constraint_data_dict: Dict,
convergence_data_dict: Dict,
world_coll_checker_dict: Optional[Dict] = None,
world_model_dict: Optional[Dict] = None,
world_coll_checker: Optional[WorldCollision] = None,
tensor_args: TensorDeviceType = TensorDeviceType(),
):
"""Create ArmBase class from dictionary
NOTE: We declare this as a classmethod to allow for derived classes to use it.
Args:
robot_cfg (Union[Dict, RobotConfig]): _description_
model_data_dict (Dict): _description_
cost_data_dict (Dict): _description_
constraint_data_dict (Dict): _description_
convergence_data_dict (Dict): _description_
world_coll_checker_dict (Optional[Dict], optional): _description_. Defaults to None.
world_model_dict (Optional[Dict], optional): _description_. Defaults to None.
world_coll_checker (Optional[WorldCollision], optional): _description_. Defaults to None.
tensor_args (TensorDeviceType, optional): _description_. Defaults to TensorDeviceType().
Returns:
_type_: _description_
"""
if isinstance(robot_cfg, dict):
robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args)
world_coll_checker = cls.world_coll_checker_from_dict(
world_coll_checker_dict, world_model_dict, world_coll_checker, tensor_args
)
model = cls.model_from_dict(model_data_dict, robot_cfg, tensor_args=tensor_args)
cost = cls.cost_from_dict(
cost_data_dict,
robot_cfg,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
constraint = cls.cost_from_dict(
constraint_data_dict,
robot_cfg,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
convergence = cls.cost_from_dict(
convergence_data_dict,
robot_cfg,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
return cls(
model_cfg=model,
cost_cfg=cost,
constraint_cfg=constraint,
convergence_cfg=convergence,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
class ArmBase(RolloutBase, ArmBaseConfig):
"""
This rollout function is for reaching a Cartesian pose with a robot.
"""
@profiler.record_function("arm_base/init")
def __init__(self, config: Optional[ArmBaseConfig] = None):
if config is not None:
ArmBaseConfig.__init__(self, **vars(config))
RolloutBase.__init__(self)
self._init_after_config_load()
@profiler.record_function("arm_base/init_after_config_load")
def _init_after_config_load(self):
# self.current_state = None
# self.retract_state = None
self._goal_buffer = Goal()
self._goal_idx_update = True
# Create the dynamical system used for rollouts
self.dynamics_model = KinematicModel(self.model_cfg)
self.n_dofs = self.dynamics_model.n_dofs
self.traj_dt = self.dynamics_model.traj_dt
if self.cost_cfg.bound_cfg is not None:
self.cost_cfg.bound_cfg.set_bounds(
self.dynamics_model.get_state_bounds(),
teleport_mode=self.dynamics_model.teleport_mode,
)
self.cost_cfg.bound_cfg.cspace_distance_weight = (
self.dynamics_model.cspace_distance_weight
)
self.cost_cfg.bound_cfg.state_finite_difference_mode = (
self.dynamics_model.state_finite_difference_mode
)
self.cost_cfg.bound_cfg.update_vec_weight(self.dynamics_model.null_space_weight)
if self.cost_cfg.null_space_cfg is not None:
self.cost_cfg.bound_cfg.null_space_weight = self.cost_cfg.null_space_cfg.weight
log_warn(
"null space cost is deprecated, use null_space_weight in bound cost instead"
)
self.bound_cost = BoundCost(self.cost_cfg.bound_cfg)
if self.cost_cfg.manipulability_cfg is not None:
self.manipulability_cost = ManipulabilityCost(self.cost_cfg.manipulability_cfg)
if self.cost_cfg.stop_cfg is not None:
self.cost_cfg.stop_cfg.horizon = self.dynamics_model.horizon
self.cost_cfg.stop_cfg.dt_traj_params = self.dynamics_model.dt_traj_params
self.stop_cost = StopCost(self.cost_cfg.stop_cfg)
self._goal_buffer.retract_state = self.retract_state
if self.cost_cfg.primitive_collision_cfg is not None:
self.primitive_collision_cost = PrimitiveCollisionCost(
self.cost_cfg.primitive_collision_cfg
)
if self.dynamics_model.robot_model.total_spheres == 0:
self.primitive_collision_cost.disable_cost()
if self.cost_cfg.self_collision_cfg is not None:
self.cost_cfg.self_collision_cfg.self_collision_kin_config = (
self.dynamics_model.robot_model.get_self_collision_config()
)
self.robot_self_collision_cost = SelfCollisionCost(self.cost_cfg.self_collision_cfg)
if self.dynamics_model.robot_model.total_spheres == 0:
self.robot_self_collision_cost.disable_cost()
# setup constraint terms:
if self.constraint_cfg.primitive_collision_cfg is not None:
self.primitive_collision_constraint = PrimitiveCollisionCost(
self.constraint_cfg.primitive_collision_cfg
)
if self.dynamics_model.robot_model.total_spheres == 0:
self.primitive_collision_constraint.disable_cost()
if self.constraint_cfg.self_collision_cfg is not None:
self.constraint_cfg.self_collision_cfg.self_collision_kin_config = (
self.dynamics_model.robot_model.get_self_collision_config()
)
self.robot_self_collision_constraint = SelfCollisionCost(
self.constraint_cfg.self_collision_cfg
)
if self.dynamics_model.robot_model.total_spheres == 0:
self.robot_self_collision_constraint.disable_cost()
self.constraint_cfg.bound_cfg.set_bounds(
self.dynamics_model.get_state_bounds(), teleport_mode=self.dynamics_model.teleport_mode
)
self.constraint_cfg.bound_cfg.cspace_distance_weight = (
self.dynamics_model.cspace_distance_weight
)
self.cost_cfg.bound_cfg.state_finite_difference_mode = (
self.dynamics_model.state_finite_difference_mode
)
self.bound_constraint = BoundCost(self.constraint_cfg.bound_cfg)
if self.convergence_cfg.null_space_cfg is not None:
self.null_convergence = DistCost(self.convergence_cfg.null_space_cfg)
# set start state:
start_state = torch.randn((1, self.dynamics_model.d_state), **vars(self.tensor_args))
self._start_state = JointState(
position=start_state[:, : self.dynamics_model.d_action],
velocity=start_state[:, : self.dynamics_model.d_action],
acceleration=start_state[:, : self.dynamics_model.d_action],
)
self.update_cost_dt(self.dynamics_model.dt_traj_params.base_dt)
return RolloutBase._init_after_config_load(self)
def cost_fn(self, state: KinematicModelState, action_batch=None, return_list=False):
# ee_pos_batch, ee_rot_batch = state_dict["ee_pos_seq"], state_dict["ee_rot_seq"]
state_batch = state.state_seq
cost_list = []
# compute state bound cost:
if self.bound_cost.enabled:
with profiler.record_function("cost/bound"):
c = self.bound_cost.forward(
state_batch,
self._goal_buffer.retract_state,
self._goal_buffer.batch_retract_state_idx,
)
cost_list.append(c)
if self.cost_cfg.manipulability_cfg is not None and self.manipulability_cost.enabled:
raise NotImplementedError("Manipulability Cost is not implemented")
if self.cost_cfg.stop_cfg is not None and self.stop_cost.enabled:
st_cost = self.stop_cost.forward(state_batch.velocity)
cost_list.append(st_cost)
if self.cost_cfg.self_collision_cfg is not None and self.robot_self_collision_cost.enabled:
with profiler.record_function("cost/self_collision"):
coll_cost = self.robot_self_collision_cost.forward(state.robot_spheres)
# cost += coll_cost
cost_list.append(coll_cost)
if (
self.cost_cfg.primitive_collision_cfg is not None
and self.primitive_collision_cost.enabled
):
with profiler.record_function("cost/collision"):
coll_cost = self.primitive_collision_cost.forward(
state.robot_spheres,
env_query_idx=self._goal_buffer.batch_world_idx,
)
cost_list.append(coll_cost)
if return_list:
return cost_list
cost = cat_sum(cost_list)
return cost
def constraint_fn(
self,
state: KinematicModelState,
out_metrics: Optional[RolloutMetrics] = None,
use_batch_env: bool = True,
) -> RolloutMetrics:
# setup constraint terms:
constraint = self.bound_constraint.forward(state.state_seq)
constraint_list = [constraint]
if (
self.constraint_cfg.primitive_collision_cfg is not None
and self.primitive_collision_constraint.enabled
):
if use_batch_env and self._goal_buffer.batch_world_idx is not None:
coll_constraint = self.primitive_collision_constraint.forward(
state.robot_spheres,
env_query_idx=self._goal_buffer.batch_world_idx,
)
else:
coll_constraint = self.primitive_collision_constraint.forward(
state.robot_spheres, env_query_idx=None
)
constraint_list.append(coll_constraint)
if (
self.constraint_cfg.self_collision_cfg is not None
and self.robot_self_collision_constraint.enabled
):
self_constraint = self.robot_self_collision_constraint.forward(state.robot_spheres)
constraint_list.append(self_constraint)
constraint = cat_sum(constraint_list)
feasible = constraint == 0.0
if out_metrics is None:
out_metrics = RolloutMetrics()
out_metrics.feasible = feasible
out_metrics.constraint = constraint
return out_metrics
def get_metrics(self, state: Union[JointState, KinematicModelState]):
"""Compute metrics given state
#TODO: Currently does not compute velocity and acceleration costs.
Args:
state (Union[JointState, KinematicModelState]): _description_
Returns:
_type_: _description_
"""
if isinstance(state, JointState):
state = self._get_augmented_state(state)
out_metrics = self.constraint_fn(state)
out_metrics.state = state
out_metrics = self.convergence_fn(state, out_metrics)
return out_metrics
def get_metrics_cuda_graph(self, state: JointState):
"""Use a CUDA Graph to compute metrics
Args:
state: _description_
Raises:
ValueError: _description_
Returns:
_description_
"""
if not self._metrics_cuda_graph_init:
# create new cuda graph for metrics:
self._cu_metrics_state_in = state.detach().clone()
s = torch.cuda.Stream(device=self.tensor_args.device)
s.wait_stream(torch.cuda.current_stream(device=self.tensor_args.device))
with torch.cuda.stream(s):
for _ in range(3):
self._cu_out_metrics = self.get_metrics(self._cu_metrics_state_in)
torch.cuda.current_stream(device=self.tensor_args.device).wait_stream(s)
self.cu_metrics_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cu_metrics_graph, stream=s):
self._cu_out_metrics = self.get_metrics(self._cu_metrics_state_in)
self._metrics_cuda_graph_init = True
self._cu_metrics_state_in.copy_(state)
self.cu_metrics_graph.replay()
out_metrics = self._cu_out_metrics
return out_metrics.clone()
@abstractmethod
def convergence_fn(
self, state: KinematicModelState, out_metrics: Optional[RolloutMetrics] = None
):
if out_metrics is None:
out_metrics = RolloutMetrics()
if (
self.convergence_cfg.null_space_cfg is not None
and self.null_convergence.enabled
and self._goal_buffer.batch_retract_state_idx is not None
):
out_metrics.cost = self.null_convergence.forward_target_idx(
self._goal_buffer.retract_state,
state.state_seq.position,
self._goal_buffer.batch_retract_state_idx,
)
return out_metrics
def _get_augmented_state(self, state: JointState) -> KinematicModelState:
aug_state = self.compute_kinematics(state)
if len(aug_state.state_seq.position.shape) == 2:
aug_state.state_seq = aug_state.state_seq.unsqueeze(1)
aug_state.ee_pos_seq = aug_state.ee_pos_seq.unsqueeze(1)
aug_state.ee_quat_seq = aug_state.ee_quat_seq.unsqueeze(1)
if aug_state.lin_jac_seq is not None:
aug_state.lin_jac_seq = aug_state.lin_jac_seq.unsqueeze(1)
if aug_state.ang_jac_seq is not None:
aug_state.ang_jac_seq = aug_state.ang_jac_seq.unsqueeze(1)
aug_state.robot_spheres = aug_state.robot_spheres.unsqueeze(1)
aug_state.link_pos_seq = aug_state.link_pos_seq.unsqueeze(1)
aug_state.link_quat_seq = aug_state.link_quat_seq.unsqueeze(1)
return aug_state
def compute_kinematics(self, state: JointState) -> KinematicModelState:
# assume input is joint state?
h = 0
current_state = state # .detach().clone()
if len(current_state.position.shape) == 1:
current_state = current_state.unsqueeze(0)
q = current_state.position
if len(q.shape) == 3:
b, h, _ = q.shape
q = q.view(b * h, -1)
(
ee_pos_seq,
ee_rot_seq,
lin_jac_seq,
ang_jac_seq,
link_pos_seq,
link_rot_seq,
link_spheres,
) = self.dynamics_model.robot_model.forward(q)
if h != 0:
ee_pos_seq = ee_pos_seq.view(b, h, 3)
ee_rot_seq = ee_rot_seq.view(b, h, 4)
if lin_jac_seq is not None:
lin_jac_seq = lin_jac_seq.view(b, h, 3, self.n_dofs)
if ang_jac_seq is not None:
ang_jac_seq = ang_jac_seq.view(b, h, 3, self.n_dofs)
link_spheres = link_spheres.view(b, h, link_spheres.shape[-2], link_spheres.shape[-1])
link_pos_seq = link_pos_seq.view(b, h, -1, 3)
link_rot_seq = link_rot_seq.view(b, h, -1, 4)
state = KinematicModelState(
current_state,
ee_pos_seq,
ee_rot_seq,
link_spheres,
link_pos_seq,
link_rot_seq,
lin_jac_seq,
ang_jac_seq,
link_names=self.kinematics.link_names,
)
return state
def rollout_constraint(
self, act_seq: torch.Tensor, use_batch_env: bool = True
) -> RolloutMetrics:
state = self.dynamics_model.forward(self.start_state, act_seq)
metrics = self.constraint_fn(state, use_batch_env=use_batch_env)
return metrics
def rollout_constraint_cuda_graph(self, act_seq: torch.Tensor, use_batch_env: bool = True):
# TODO: move this to RolloutBase
if not self._rollout_constraint_cuda_graph_init:
# create new cuda graph for metrics:
self._cu_rollout_constraint_act_in = act_seq.clone()
s = torch.cuda.Stream(device=self.tensor_args.device)
s.wait_stream(torch.cuda.current_stream(device=self.tensor_args.device))
with torch.cuda.stream(s):
for _ in range(3):
state = self.dynamics_model.forward(self.start_state, act_seq)
self._cu_rollout_constraint_out_metrics = self.constraint_fn(
state, use_batch_env=use_batch_env
)
torch.cuda.current_stream(device=self.tensor_args.device).wait_stream(s)
self.cu_rollout_constraint_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cu_rollout_constraint_graph, stream=s):
state = self.dynamics_model.forward(self.start_state, act_seq)
self._cu_rollout_constraint_out_metrics = self.constraint_fn(
state, use_batch_env=use_batch_env
)
self._rollout_constraint_cuda_graph_init = True
self._cu_rollout_constraint_act_in.copy_(act_seq)
self.cu_rollout_constraint_graph.replay()
out_metrics = self._cu_rollout_constraint_out_metrics
return out_metrics.clone()
def rollout_fn(self, act_seq) -> Trajectory:
"""
Return sequence of costs and states encountered
by simulating a batch of action sequences
Parameters
----------
act_seq: torch.Tensor [num_particles, horizon, d_act]
"""
# print(act_seq.shape, self._goal_buffer.batch_current_state_idx)
if self.start_state is None:
raise ValueError("start_state is not set in rollout")
with profiler.record_function("robot_model/rollout"):
state = self.dynamics_model.forward(
self.start_state, act_seq, self._goal_buffer.batch_current_state_idx
)
with profiler.record_function("cost/all"):
cost_seq = self.cost_fn(state, act_seq)
sim_trajs = Trajectory(actions=act_seq, costs=cost_seq, state=state)
return sim_trajs
def update_params(self, goal: Goal):
"""
Updates the goal targets for the cost functions.
"""
with profiler.record_function("arm_base/update_params"):
self._goal_buffer.copy_(
goal, update_idx_buffers=self._goal_idx_update
) # TODO: convert this to a reference to avoid extra copy
# self._goal_buffer.copy_(goal, update_idx_buffers=True) # TODO: convert this to a reference to avoid extra copy
# TODO: move start state also inside Goal instance
if goal.current_state is not None:
if self.start_state is None:
self.start_state = goal.current_state.clone()
else:
self.start_state = self.start_state.copy_(goal.current_state)
self.batch_size = goal.batch
return True
def get_ee_pose(self, current_state):
current_state = current_state.to(**self.tensor_args)
(ee_pos_batch, ee_quat_batch) = self.dynamics_model.robot_model.forward(
current_state[:, : self.dynamics_model.n_dofs]
)[0:2]
state = KinematicModelState(current_state, ee_pos_batch, ee_quat_batch)
return state
def current_cost(self, current_state: JointState, no_coll=False, return_state=True, **kwargs):
state = self._get_augmented_state(current_state)
if "horizon_cost" not in kwargs:
kwargs["horizon_cost"] = False
cost = self.cost_fn(state, None, no_coll=no_coll, **kwargs)
if return_state:
return cost, state
else:
return cost
def filter_robot_state(self, current_state: JointState) -> JointState:
return self.dynamics_model.filter_robot_state(current_state)
def get_robot_command(
self,
current_state: JointState,
act_seq: torch.Tensor,
shift_steps: int = 1,
state_idx: Optional[torch.Tensor] = None,
) -> JointState:
return self.dynamics_model.get_robot_command(
current_state,
act_seq,
shift_steps=shift_steps,
state_idx=state_idx,
)
def reset(self):
self.dynamics_model.state_filter.reset()
super().reset()
@property
def d_action(self):
return self.dynamics_model.d_action
@property
def action_bound_lows(self):
return self.dynamics_model.action_bound_lows
@property
def action_bound_highs(self):
return self.dynamics_model.action_bound_highs
@property
def state_bounds(self) -> Dict[str, List[float]]:
return self.dynamics_model.get_state_bounds()
@property
def dt(self):
return self.dynamics_model.dt
@property
def horizon(self):
return self.dynamics_model.horizon
def get_init_action_seq(self) -> torch.Tensor:
act_seq = self.dynamics_model.init_action_mean.unsqueeze(0).repeat(self.batch_size, 1, 1)
return act_seq
def reset_cuda_graph(self):
self._goal_idx_update = True
super().reset_cuda_graph()
def get_action_from_state(self, state: JointState):
return self.dynamics_model.get_action_from_state(state)
def get_state_from_action(
self,
start_state: JointState,
act_seq: torch.Tensor,
state_idx: Optional[torch.Tensor] = None,
):
return self.dynamics_model.get_state_from_action(start_state, act_seq, state_idx)
@property
def kinematics(self):
return self.dynamics_model.robot_model
@property
def cspace_config(self) -> CSpaceConfig:
return self.dynamics_model.robot_model.kinematics_config.cspace
def get_full_dof_from_solution(self, q_js: JointState) -> JointState:
"""This function will all the dof that are locked during optimization.
Args:
q_sol: _description_
Returns:
_description_
"""
if self.kinematics.lock_jointstate is None:
return q_js
all_joint_names = self.kinematics.all_articulated_joint_names
lock_joint_state = self.kinematics.lock_jointstate
new_js = q_js.get_augmented_joint_state(all_joint_names, lock_joint_state)
return new_js
@property
def joint_names(self) -> List[str]:
return self.kinematics.joint_names
@property
def retract_state(self):
return self.dynamics_model.retract_config
def update_traj_dt(
self,
dt: Union[float, torch.Tensor],
base_dt: Optional[float] = None,
max_dt: Optional[float] = None,
base_ratio: Optional[float] = None,
):
self.dynamics_model.update_traj_dt(dt, base_dt, max_dt, base_ratio)
self.update_cost_dt(dt)
def update_cost_dt(self, dt: float):
# scale any temporal costs by dt:
self.bound_cost.update_dt(dt)
if self.cost_cfg.primitive_collision_cfg is not None:
self.primitive_collision_cost.update_dt(dt)
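(Not part of this commit) A minimal sketch of how the dictionary-driven configuration above is typically assembled. The robot_dict, task_dict, and world_dict variables and their keys are assumptions standing in for the YAML files shipped with the library; the resulting config feeds a concrete rollout such as ArmReacher (next file), since ArmBase itself is abstract.
# Hypothetical usage sketch; dictionary names and keys are placeholders.
config = ArmBaseConfig.from_dict(
    robot_cfg=robot_dict,  # or a RobotConfig instance
    model_data_dict=task_dict["model"],
    cost_data_dict=task_dict["cost"],
    constraint_data_dict=task_dict["constraint"],
    convergence_data_dict=task_dict["convergence"],
    world_coll_checker_dict=task_dict.get("world_collision_checker_cfg"),
    world_model_dict=world_dict,
    tensor_args=TensorDeviceType(),
)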

View File

@@ -0,0 +1,403 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from typing import Dict, Optional
# Third Party
import torch
import torch.autograd.profiler as profiler
# CuRobo
from curobo.geom.sdf.world import WorldCollision
from curobo.rollout.cost.cost_base import CostConfig
from curobo.rollout.cost.dist_cost import DistCost, DistCostConfig
from curobo.rollout.cost.pose_cost import PoseCost, PoseCostConfig
from curobo.rollout.cost.straight_line_cost import StraightLineCost
from curobo.rollout.cost.zero_cost import ZeroCost
from curobo.rollout.dynamics_model.kinematic_model import KinematicModelState
from curobo.rollout.rollout_base import Goal, RolloutMetrics
from curobo.types.base import TensorDeviceType
from curobo.types.robot import RobotConfig
from curobo.types.tensor import T_BValue_float
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.tensor_util import cat_max, cat_sum
# Local Folder
from .arm_base import ArmBase, ArmBaseConfig, ArmCostConfig
@dataclass
class ArmReacherMetrics(RolloutMetrics):
cspace_error: Optional[T_BValue_float] = None
position_error: Optional[T_BValue_float] = None
rotation_error: Optional[T_BValue_float] = None
pose_error: Optional[T_BValue_float] = None
def __getitem__(self, idx):
d_list = [
self.cost,
self.constraint,
self.feasible,
self.state,
self.cspace_error,
self.position_error,
self.rotation_error,
self.pose_error,
]
idx_vals = list_idx_if_not_none(d_list, idx)
return ArmReacherMetrics(*idx_vals)
def clone(self, clone_state=False):
if clone_state:
raise NotImplementedError()
return ArmReacherMetrics(
cost=None if self.cost is None else self.cost.clone(),
constraint=None if self.constraint is None else self.constraint.clone(),
feasible=None if self.feasible is None else self.feasible.clone(),
state=None if self.state is None else self.state,
cspace_error=None if self.cspace_error is None else self.cspace_error,
position_error=None if self.position_error is None else self.position_error,
rotation_error=None if self.rotation_error is None else self.rotation_error,
pose_error=None if self.pose_error is None else self.pose_error,
)
@dataclass
class ArmReacherCostConfig(ArmCostConfig):
pose_cfg: Optional[PoseCostConfig] = None
cspace_cfg: Optional[DistCostConfig] = None
straight_line_cfg: Optional[CostConfig] = None
zero_acc_cfg: Optional[CostConfig] = None
zero_vel_cfg: Optional[CostConfig] = None
zero_jerk_cfg: Optional[CostConfig] = None
@staticmethod
def _get_base_keys():
base_k = ArmCostConfig._get_base_keys()
# add new cost terms:
new_k = {
"pose_cfg": PoseCostConfig,
"cspace_cfg": DistCostConfig,
"straight_line_cfg": CostConfig,
"zero_acc_cfg": CostConfig,
"zero_vel_cfg": CostConfig,
"zero_jerk_cfg": CostConfig,
}
new_k.update(base_k)
return new_k
@staticmethod
def from_dict(
data_dict: Dict,
robot_cfg: RobotConfig,
world_coll_checker: Optional[WorldCollision] = None,
tensor_args: TensorDeviceType = TensorDeviceType(),
):
k_list = ArmReacherCostConfig._get_base_keys()
data = ArmCostConfig._get_formatted_dict(
data_dict,
k_list,
robot_cfg,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
return ArmReacherCostConfig(**data)
@dataclass
class ArmReacherConfig(ArmBaseConfig):
cost_cfg: ArmReacherCostConfig
constraint_cfg: ArmReacherCostConfig
convergence_cfg: ArmReacherCostConfig
@staticmethod
def cost_from_dict(
cost_data_dict: Dict,
robot_cfg: RobotConfig,
world_coll_checker: Optional[WorldCollision] = None,
tensor_args: TensorDeviceType = TensorDeviceType(),
):
return ArmReacherCostConfig.from_dict(
cost_data_dict,
robot_cfg,
world_coll_checker=world_coll_checker,
tensor_args=tensor_args,
)
@torch.jit.script
def _compute_g_dist_jit(rot_err_norm, goal_dist):
# goal_cost = goal_cost.view(cost.shape)
# rot_err_norm = rot_err_norm.view(cost.shape)
# goal_dist = goal_dist.view(cost.shape)
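# combined goal distance: translation error plus rotation error scaled by 10; used to gate the zero velocity/acceleration/jerk hinge costs near the goal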
g_dist = goal_dist.unsqueeze(-1) + 10.0 * rot_err_norm.unsqueeze(-1)
return g_dist
class ArmReacher(ArmBase, ArmReacherConfig):
"""
.. inheritance-diagram:: curobo.rollout.arm_reacher.ArmReacher
"""
@profiler.record_function("arm_reacher/init")
def __init__(self, config: Optional[ArmReacherConfig] = None):
if config is not None:
ArmReacherConfig.__init__(self, **vars(config))
ArmBase.__init__(self)
# self.goal_state = None
# self.goal_ee_pos = None
# self.goal_ee_rot = None
# self.goal_ee_quat = None
self._compute_g_dist = False
self._n_goalset = 1
if self.cost_cfg.cspace_cfg is not None:
# self.cost_cfg.cspace_cfg.update_vec_weight(self.dynamics_model.cspace_distance_weight)
self.dist_cost = DistCost(self.cost_cfg.cspace_cfg)
if self.cost_cfg.pose_cfg is not None:
self.goal_cost = PoseCost(self.cost_cfg.pose_cfg)
self._link_pose_costs = {}
for i in self.kinematics.link_names:
if i != self.kinematics.ee_link:
self._link_pose_costs[i] = PoseCost(self.cost_cfg.pose_cfg)
if self.cost_cfg.straight_line_cfg is not None:
self.straight_line_cost = StraightLineCost(self.cost_cfg.straight_line_cfg)
if self.cost_cfg.zero_vel_cfg is not None:
self.zero_vel_cost = ZeroCost(self.cost_cfg.zero_vel_cfg)
self._max_vel = self.state_bounds["velocity"][1]
if self.zero_vel_cost.hinge_value is not None:
self._compute_g_dist = True
if self.cost_cfg.zero_acc_cfg is not None:
self.zero_acc_cost = ZeroCost(self.cost_cfg.zero_acc_cfg)
self._max_vel = self.state_bounds["velocity"][1]
if self.zero_acc_cost.hinge_value is not None:
self._compute_g_dist = True
if self.cost_cfg.zero_jerk_cfg is not None:
self.zero_jerk_cost = ZeroCost(self.cost_cfg.zero_jerk_cfg)
self._max_vel = self.state_bounds["velocity"][1]
if self.zero_jerk_cost.hinge_value is not None:
self._compute_g_dist = True
self.z_tensor = torch.tensor(
0, device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
if self.convergence_cfg.pose_cfg is not None:
self.pose_convergence = PoseCost(self.convergence_cfg.pose_cfg)
self._link_pose_convergence = {}
for i in self.kinematics.link_names:
if i != self.kinematics.ee_link:
self._link_pose_convergence[i] = PoseCost(self.convergence_cfg.pose_cfg)
if self.convergence_cfg.cspace_cfg is not None:
self.cspace_convergence = DistCost(self.convergence_cfg.cspace_cfg)
# check if g_dist is required in any of the cost terms:
self.update_params(Goal(current_state=self._start_state))
def cost_fn(self, state: KinematicModelState, action_batch=None):
"""
Compute cost given the state dictionary and actions.
:class:`curobo.rollout.cost.PoseCost`
:class:`curobo.rollout.cost.DistCost`
"""
state_batch = state.state_seq
with profiler.record_function("cost/base"):
cost_list = super(ArmReacher, self).cost_fn(state, action_batch, return_list=True)
ee_pos_batch, ee_quat_batch = state.ee_pos_seq, state.ee_quat_seq
g_dist = None
with profiler.record_function("cost/pose"):
if (
self._goal_buffer.goal_pose.position is not None
and self.cost_cfg.pose_cfg is not None
and self.goal_cost.enabled
):
if self._compute_g_dist:
goal_cost, rot_err_norm, goal_dist = self.goal_cost.forward_out_distance(
ee_pos_batch,
ee_quat_batch,
self._goal_buffer,
)
g_dist = _compute_g_dist_jit(rot_err_norm, goal_dist)
else:
goal_cost = self.goal_cost.forward(
ee_pos_batch, ee_quat_batch, self._goal_buffer
)
cost_list.append(goal_cost)
with profiler.record_function("cost/link_poses"):
if self._goal_buffer.links_goal_pose is not None and self.cost_cfg.pose_cfg is not None:
link_poses = state.link_pose
for k in self._goal_buffer.links_goal_pose.keys():
if k != self.kinematics.ee_link:
current_fn = self._link_pose_costs[k]
if current_fn.enabled:
# get link pose
current_pose = link_poses[k]
current_pos = current_pose.position
current_quat = current_pose.quaternion
c = current_fn.forward(current_pos, current_quat, self._goal_buffer, k)
cost_list.append(c)
if (
self._goal_buffer.goal_state is not None
and self.cost_cfg.cspace_cfg is not None
and self.dist_cost.enabled
):
joint_cost = self.dist_cost.forward_target_idx(
self._goal_buffer.goal_state.position,
state_batch.position,
self._goal_buffer.batch_goal_state_idx,
)
cost_list.append(joint_cost)
if self.cost_cfg.straight_line_cfg is not None and self.straight_line_cost.enabled:
st_cost = self.straight_line_cost.forward(ee_pos_batch)
cost_list.append(st_cost)
if (
self.cost_cfg.zero_acc_cfg is not None
and self.zero_acc_cost.enabled
# and g_dist is not None
):
z_acc = self.zero_acc_cost.forward(
state_batch.acceleration,
g_dist,
)
# cost += z_acc
cost_list.append(z_acc)
# print(self.cost_cfg.zero_jerk_cfg)
if (
self.cost_cfg.zero_jerk_cfg is not None
and self.zero_jerk_cost.enabled
# and g_dist is not None
):
# jerk = self.dynamics_model._aux_matrix @ state_batch.acceleration
z_jerk = self.zero_jerk_cost.forward(
state_batch.jerk,
g_dist,
)
cost_list.append(z_jerk)
# cost += z_jerk
if (
self.cost_cfg.zero_vel_cfg is not None
and self.zero_vel_cost.enabled
# and g_dist is not None
):
z_vel = self.zero_vel_cost.forward(
state_batch.velocity,
g_dist,
)
# cost += z_vel
# print(z_vel.shape)
cost_list.append(z_vel)
cost = cat_sum(cost_list)
return cost
def convergence_fn(
self, state: KinematicModelState, out_metrics: Optional[ArmReacherMetrics] = None
) -> ArmReacherMetrics:
if out_metrics is None:
out_metrics = ArmReacherMetrics()
if not isinstance(out_metrics, ArmReacherMetrics):
out_metrics = ArmReacherMetrics(**vars(out_metrics))
# print(self._goal_buffer.batch_retract_state_idx)
out_metrics = super(ArmReacher, self).convergence_fn(state, out_metrics)
# compute error with pose?
if (
self._goal_buffer.goal_pose.position is not None
and self.convergence_cfg.pose_cfg is not None
):
(
out_metrics.pose_error,
out_metrics.rotation_error,
out_metrics.position_error,
) = self.pose_convergence.forward_out_distance(
state.ee_pos_seq, state.ee_quat_seq, self._goal_buffer
)
if (
self._goal_buffer.links_goal_pose is not None
and self.convergence_cfg.pose_cfg is not None
):
pose_error = [out_metrics.pose_error]
position_error = [out_metrics.position_error]
quat_error = [out_metrics.rotation_error]
link_poses = state.link_pose
for k in self._goal_buffer.links_goal_pose.keys():
if k != self.kinematics.ee_link:
current_fn = self._link_pose_convergence[k]
if current_fn.enabled:
# get link pose
current_pos = link_poses[k].position
current_quat = link_poses[k].quaternion
pose_err, pos_err, quat_err = current_fn.forward_out_distance(
current_pos, current_quat, self._goal_buffer, k
)
pose_error.append(pose_err)
position_error.append(pos_err)
quat_error.append(quat_err)
out_metrics.pose_error = cat_max(pose_error)
out_metrics.rotation_error = cat_max(quat_error)
out_metrics.position_error = cat_max(position_error)
if (
self._goal_buffer.goal_state is not None
and self.convergence_cfg.cspace_cfg is not None
and self.cspace_convergence.enabled
):
_, out_metrics.cspace_error = self.cspace_convergence.forward_target_idx(
self._goal_buffer.goal_state.position,
state.state_seq.position,
self._goal_buffer.batch_goal_state_idx,
True,
)
return out_metrics
def update_params(
self,
goal: Goal,
):
"""
Update params for the cost terms and dynamics model.
"""
super(ArmReacher, self).update_params(goal)
if goal.batch_pose_idx is not None:
self._goal_idx_update = False
if goal.goal_pose.position is not None:
self.enable_cspace_cost(False)
return True
def enable_pose_cost(self, enable: bool = True):
if enable:
self.goal_cost.enable_cost()
else:
self.goal_cost.disable_cost()
def enable_cspace_cost(self, enable: bool = True):
if enable:
self.dist_cost.enable_cost()
self.cspace_convergence.enable_cost()
else:
self.dist_cost.disable_cost()
self.cspace_convergence.disable_cost()
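(Not part of this commit) A conceptual sketch of how an optimizer drives ArmReacher; arm_reacher, goal, act_seq, and current_state are assumed to be constructed elsewhere (e.g., from an ArmReacherConfig built with from_dict), so this is pseudo-usage rather than runnable code.
# Hypothetical driver steps; all objects below are assumptions.
arm_reacher.update_params(goal)  # copy goal pose / goal state into the rollout buffers
trajectories = arm_reacher.rollout_fn(act_seq)  # costs and states for a batch of action sequences
metrics = arm_reacher.get_metrics(current_state)  # feasibility, pose and cspace errors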

View File

@@ -0,0 +1,13 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
"""
"""

File diff suppressed because it is too large

View File

@@ -0,0 +1,121 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from typing import List, Optional, Union
# Third Party
import torch
# CuRobo
from curobo.types.base import TensorDeviceType
@dataclass
class CostConfig:
weight: Union[torch.Tensor, float, List[float]]
tensor_args: TensorDeviceType = None
distance_threshold: float = 0.0
classify: bool = False
terminal: bool = False
run_weight: Optional[float] = None
dof: int = 7
vec_weight: Optional[Union[torch.Tensor, List[float], float]] = None
max_value: Optional[float] = None
hinge_value: Optional[float] = None
vec_convergence: Optional[List[float]] = None
threshold_value: Optional[float] = None
return_loss: bool = False
def __post_init__(self):
self.weight = self.tensor_args.to_device(self.weight)
if len(self.weight.shape) == 0:
self.weight = torch.tensor(
[self.weight], device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
if self.vec_weight is not None:
self.vec_weight = self.tensor_args.to_device(self.vec_weight)
if self.max_value is not None:
self.max_value = self.tensor_args.to_device(self.max_value)
if self.hinge_value is not None:
self.hinge_value = self.tensor_args.to_device(self.hinge_value)
if self.threshold_value is not None:
self.threshold_value = self.tensor_args.to_device(self.threshold_value)
def update_vec_weight(self, vec_weight):
self.vec_weight = self.tensor_args.to_device(vec_weight)
class CostBase(torch.nn.Module, CostConfig):
def __init__(self, config: Optional[CostConfig] = None):
"""Initialize class
Args:
config (Optional[CostConfig], optional): Pass a config to initialize this class directly.
When used as a base class, the child class is expected to initialize its own config
(derived from CostConfig) before calling this constructor. Defaults to None.
"""
self._run_weight_vec = None
super(CostBase, self).__init__()
if config is not None:
CostConfig.__init__(self, **vars(config))
CostBase._init_post_config(self)
self._batch_size = -1
self._horizon = -1
self._dof = -1
self._dt = 1
def _init_post_config(self):
self._weight = self.weight.clone()
self.cost_fn = None
self._cost_enabled = True
self._z_scalar = self.tensor_args.to_device(0.0)
if torch.sum(self.weight) == 0.0:
self.disable_cost()
def forward(self, q):
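# generic term evaluation: cost_fn (set by the subclass) maps the flattened [batch*horizon, dof] input to a per-timestep value, which is shifted by distance_threshold, clamped at zero, optionally bumped by +1 when classify is set, and scaled by weight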
batch_size = q.shape[0]
horizon = q.shape[1]
q = q.view(batch_size * horizon, q.shape[2])
res = self.cost_fn(q)
res = res.view(batch_size, horizon)
res += self.distance_threshold
res = torch.nn.functional.relu(res, inplace=True)
if self.classify:
res = torch.where(res > 0, res + 1.0, res)
cost = self.weight * res
return cost
def disable_cost(self):
self.weight.copy_(self._weight * 0.0)
self._cost_enabled = False
def enable_cost(self):
self.weight.copy_(self._weight.clone())
if torch.sum(self.weight) == 0.0:
self._cost_enabled = False
else:
self._cost_enabled = True
def update_weight(self, weight: float):
if weight == 0.0:
self.disable_cost()
else:
self.weight.copy_(self._weight * 0.0 + weight)
@property
def enabled(self):
return self._cost_enabled
def update_dt(self, dt: Union[float, torch.Tensor]):
self._dt = dt
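(Not part of this commit) A minimal sketch of a custom term built on CostBase/CostConfig, following the same initialization pattern as the cost classes in this commit; the quadratic penalty is a made-up example and assumes a CUDA device is available for the default TensorDeviceType.
# Hypothetical example cost term.
from dataclasses import dataclass

import torch

from curobo.rollout.cost.cost_base import CostBase, CostConfig
from curobo.types.base import TensorDeviceType


@dataclass
class QuadraticCostConfig(CostConfig):
    pass


class QuadraticCost(CostBase, QuadraticCostConfig):
    def __init__(self, config: QuadraticCostConfig):
        QuadraticCostConfig.__init__(self, **vars(config))
        CostBase.__init__(self)

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        # q is [batch, horizon, dof]; return a [batch, horizon] cost
        return self.weight * torch.sum(q * q, dim=-1)


cfg = QuadraticCostConfig(weight=1.0, tensor_args=TensorDeviceType())
term = QuadraticCost(cfg)
print(term.forward(torch.ones(2, 30, 7, **vars(cfg.tensor_args))).shape)  # (2, 30)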

View File

@@ -0,0 +1,315 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from enum import Enum
from typing import Optional
# Third Party
import torch
import warp as wp
# CuRobo
from curobo.util.warp import init_warp
# Local Folder
from .cost_base import CostBase, CostConfig
wp.set_module_options({"fast_math": False})
class DistType(Enum):
L1 = 0
L2 = 1
SQUARED_L2 = 2
@dataclass
class DistCostConfig(CostConfig):
dist_type: DistType = DistType.L2
use_null_space: bool = False
def __post_init__(self):
return super().__post_init__()
@torch.jit.script
def L2_DistCost_jit(vec_weight, disp_vec):
return torch.norm(vec_weight * disp_vec, p=2, dim=-1, keepdim=False)
@torch.jit.script
def fwd_SQL2_DistCost_jit(vec_weight, disp_vec):
return torch.sum(torch.square(vec_weight * disp_vec), dim=-1, keepdim=False)
@torch.jit.script
def fwd_L1_DistCost_jit(vec_weight, disp_vec):
return torch.sum(torch.abs(vec_weight * disp_vec), dim=-1, keepdim=False)
@torch.jit.script
def L2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
return torch.norm(weight * vec_weight * (g_vec - c_vec), p=2, dim=-1, keepdim=False)
@torch.jit.script
def fwd_SQL2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
return torch.sum(torch.square(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
@torch.jit.script
def fwd_L1_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
return torch.sum(torch.abs(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
@wp.kernel
def forward_l2_warp(
pos: wp.array(dtype=wp.float32),
target: wp.array(dtype=wp.float32),
target_idx: wp.array(dtype=wp.int32),
weight: wp.array(dtype=wp.float32),
run_weight: wp.array(dtype=wp.float32),
vec_weight: wp.array(dtype=wp.float32),
out_cost: wp.array(dtype=wp.float32),
out_grad_p: wp.array(dtype=wp.float32),
write_grad: wp.uint8, # this should be a bool
batch_size: wp.int32,
horizon: wp.int32,
dof: wp.int32,
):
tid = wp.tid()
# initialize variables:
b_id = wp.int32(0)
h_id = wp.int32(0)
d_id = wp.int32(0)
b_addrs = wp.int32(0)
target_id = wp.int32(0)
w = wp.float32(0.0)
c_p = wp.float32(0.0)
target_p = wp.float32(0.0)
g_p = wp.float32(0.0)
r_w = wp.float32(0.0)
c_total = wp.float32(0.0)
# we launch batch * horizon * dof kernels
b_id = tid / (horizon * dof)
h_id = (tid - (b_id * horizon * dof)) / dof
d_id = tid - (b_id * horizon * dof + h_id * dof)
if b_id >= batch_size or h_id >= horizon or d_id >= dof:
return
# read weights:
w = weight[0]
r_w = run_weight[h_id]
w = r_w * w
r_w = vec_weight[d_id]
w = r_w * w
if w == 0.0:
return
# compute cost:
b_addrs = b_id * horizon * dof + h_id * dof + d_id
# read buffers:
c_p = pos[b_addrs]
target_id = target_idx[b_id]
target_id = target_id * dof + d_id
target_p = target[target_id]
error = c_p - target_p
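# heavily weighted terminal steps (r_w >= 1 and w > 100) use a smoothed log-cosh penalty, other steps a plain squared error; the corresponding gradient is written to out_grad_p when write_grad is set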
if r_w >= 1.0 and w > 100.0:
c_total = w * wp.log2(wp.cosh(50.0 * error))
g_p = w * 50.0 * wp.sinh(50.0 * error) / (wp.cosh(50.0 * error))
else:
c_total = w * error * error
g_p = 2.0 * w * error
out_cost[b_addrs] = c_total
# compute gradient
if write_grad == 1:
out_grad_p[b_addrs] = g_p
# create a bound cost tensor:
class L2DistFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
pos,
target,
target_idx,
weight,
run_weight,
vec_weight,
out_cost,
out_cost_v,
out_gp,
):
wp_device = wp.device_from_torch(pos.device)
b, h, dof = pos.shape
# print(target)
wp.launch(
kernel=forward_l2_warp,
dim=b * h * dof,
inputs=[
wp.from_torch(pos.detach().reshape(-1), dtype=wp.float32),
wp.from_torch(target.view(-1), dtype=wp.float32),
wp.from_torch(target_idx.view(-1), dtype=wp.int32),
wp.from_torch(weight, dtype=wp.float32),
wp.from_torch(run_weight.view(-1), dtype=wp.float32),
wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
wp.from_torch(out_cost_v.view(-1), dtype=wp.float32),
wp.from_torch(out_gp.view(-1), dtype=wp.float32),
pos.requires_grad,
b,
h,
dof,
],
device=wp_device,
stream=wp.stream_from_torch(pos.device),
)
# cost = torch.linalg.norm(out_cost_v, dim=-1)
# if pos.requires_grad:
# out_gp = out_gp * torch.nan_to_num( 1.0/cost.unsqueeze(-1), 0.0)
cost = torch.sum(out_cost_v, dim=-1)
ctx.save_for_backward(out_gp)
return cost
@staticmethod
def backward(ctx, grad_out_cost):
(p_grad,) = ctx.saved_tensors
p_g = None
if ctx.needs_input_grad[0]:
p_g = p_grad
return p_g, None, None, None, None, None, None, None, None
class DistCost(CostBase, DistCostConfig):
def __init__(self, config: Optional[DistCostConfig] = None):
if config is not None:
DistCostConfig.__init__(self, **vars(config))
CostBase.__init__(self)
self._init_post_config()
init_warp()
def _init_post_config(self):
if self.vec_weight is not None:
self.vec_weight = self.tensor_args.to_device(self.vec_weight)
if not self.use_null_space:
self.vec_weight = self.vec_weight * 0.0 + 1.0
def update_batch_size(self, batch, horizon, dof):
if self._batch_size != batch or self._horizon != horizon or self._dof != dof:
self._out_cv_buffer = torch.zeros(
(batch, horizon, dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self._out_c_buffer = torch.zeros(
(batch, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self._out_g_buffer = torch.zeros(
(batch, horizon, dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self._batch_size = batch
self._horizon = horizon
self._dof = dof
if self.vec_weight is None:
self.vec_weight = torch.ones(
(1, 1, self._dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
def forward(self, disp_vec, RETURN_GOAL_DIST=False):
if self.dist_type == DistType.L2:
# dist = torch.norm(disp_vec, p=2, dim=-1, keepdim=False)
dist = L2_DistCost_jit(self.vec_weight, disp_vec)
elif self.dist_type == DistType.SQUARED_L2:
# cost = weight * (0.5 * torch.square(torch.norm(disp_vec, p=2, dim=-1)))
# dist = torch.sum(torch.square(disp_vec), dim=-1, keepdim=False)
dist = fwd_SQL2_DistCost_jit(self.vec_weight, disp_vec)
elif self.dist_type == DistType.L1:
# dist = torch.sum(torch.abs(disp_vec), dim=-1, keepdim=False)
dist = fwd_L1_DistCost_jit(self.vec_weight, disp_vec)
cost = self.weight * dist
if self.terminal and self.run_weight is not None:
if self._run_weight_vec is None or self._run_weight_vec.shape[1] != cost.shape[1]:
self._run_weight_vec = torch.ones(
(1, cost.shape[1]), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self._run_weight_vec[:, :-1] *= self.run_weight
if RETURN_GOAL_DIST:
return cost, dist
return cost
def forward_target(self, goal_vec, current_vec, RETURN_GOAL_DIST=False):
if self.dist_type == DistType.L2:
# dist = torch.norm(disp_vec, p=2, dim=-1, keepdim=False)
cost = L2_DistCost_target_jit(self.vec_weight, goal_vec, current_vec, self.weight)
elif self.dist_type == DistType.SQUARED_L2:
# cost = weight * (0.5 * torch.square(torch.norm(disp_vec, p=2, dim=-1)))
# dist = torch.sum(torch.square(disp_vec), dim=-1, keepdim=False)
cost = fwd_SQL2_DistCost_target_jit(self.vec_weight, goal_vec, current_vec, self.weight)
elif self.dist_type == DistType.L1:
# dist = torch.sum(torch.abs(disp_vec), dim=-1, keepdim=False)
cost = fwd_L1_DistCost_target_jit(self.vec_weight, goal_vec, current_vec, self.weight)
dist = cost
if self.terminal and self.run_weight is not None:
if self._run_weight_vec is None or self._run_weight_vec.shape[1] != cost.shape[1]:
self._run_weight_vec = torch.ones(
(1, cost.shape[1]), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self._run_weight_vec[:, :-1] *= self.run_weight
cost = self._run_weight_vec * dist
if RETURN_GOAL_DIST:
return cost, dist / self.weight
return cost
def forward_target_idx(self, goal_vec, current_vec, goal_idx, RETURN_GOAL_DIST=False):
b, h, dof = current_vec.shape
self.update_batch_size(b, h, dof)
if self.terminal and self.run_weight is not None:
if self._run_weight_vec is None or self._run_weight_vec.shape[1] != h:
self._run_weight_vec = torch.ones(
(1, h), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self._run_weight_vec[:, :-1] *= self.run_weight
else:
raise NotImplementedError("terminal flag needs to be set to true")
if self.dist_type == DistType.L2:
# print(goal_idx.shape, goal_vec.shape)
cost = L2DistFunction.apply(
current_vec,
goal_vec,
goal_idx,
self.weight,
self._run_weight_vec,
self.vec_weight,
self._out_c_buffer,
self._out_cv_buffer,
self._out_g_buffer,
)
# cost = torch.linalg.norm(cost, dim=-1)
else:
raise NotImplementedError()
# print(cost.shape, cost[:,-1])
if RETURN_GOAL_DIST:
return cost, (cost / torch.sqrt((self.weight * self._run_weight_vec)))
return cost
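(Not part of this commit) A small sketch of calling DistCost directly on a displacement batch; the weight values and tensor shapes are placeholder assumptions and a CUDA device is assumed for the default TensorDeviceType.
# Hypothetical example; values are placeholders.
import torch

from curobo.rollout.cost.dist_cost import DistCost, DistCostConfig, DistType
from curobo.types.base import TensorDeviceType

tensor_args = TensorDeviceType()
cfg = DistCostConfig(
    weight=1.0,
    vec_weight=[1.0] * 7,  # must be set for forward(); reset to ones when use_null_space=False
    tensor_args=tensor_args,
    dist_type=DistType.L2,
)
dist_cost = DistCost(cfg)
disp_vec = torch.randn(4, 30, 7, **vars(tensor_args))  # e.g., goal_q - current_q
cost = dist_cost.forward(disp_vec)  # [batch, horizon]
print(cost.shape)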

View File

@@ -0,0 +1,107 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from itertools import product
# Third Party
import torch
# Local Folder
from .cost_base import CostBase, CostConfig
@dataclass
class ManipulabilityCostConfig(CostConfig):
use_joint_limits: bool = False
def __post_init__(self):
return super().__post_init__()
class ManipulabilityCost(CostBase, ManipulabilityCostConfig):
def __init__(self, config: ManipulabilityCostConfig):
ManipulabilityCostConfig.__init__(self, **vars(config))
CostBase.__init__(self)
self.i_mat = torch.ones(
(6, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self.delta_vector = torch.zeros(
(64, 1, 1, 6, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
x = [i for i in product(range(2), repeat=6)]
self.delta_vector[:, 0, 0, :, 0] = torch.as_tensor(
x, device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self.delta_vector[self.delta_vector == 0] = -1.0
if self.cost_fn is None:
if self.use_joint_limits and self.joint_limits is not None:
self.cost_fn = self.joint_limited_manipulability_delta
else:
self.cost_fn = self.manipulability
def forward(self, jac_batch, q, qdot):
b, h, n = q.shape
if self.use_nn:
q = q.view(b * h, n)
score = self.cost_fn(q, jac_batch, qdot)
if self.use_nn:
score = score.view(b, h)
score[score > self.hinge_value] = self.hinge_value
score = (self.hinge_value / score) - 1
cost = self.weight * score
return cost
def manipulability(self, q, jac_batch, qdot=None):
with torch.cuda.amp.autocast(enabled=False):
J_J_t = torch.matmul(jac_batch, jac_batch.transpose(-2, -1))
score = torch.sqrt(torch.det(J_J_t))
score[score != score] = 0.0
return score
def joint_limited_manipulability_delta(self, q, jac_batch, qdot=None):
# q is [b,h,dof]
q_low = q - self.joint_limits[:, 0]
q_high = q - self.joint_limits[:, 1]
d_h_1 = torch.square(self.joint_limits[:, 1] - self.joint_limits[:, 0]) * (q_low + q_high)
d_h_2 = 4.0 * (torch.square(q_low) * torch.square(q_high))
d_h = torch.div(d_h_1, d_h_2)
dh_term = 1.0 / torch.sqrt(1 + torch.abs(d_h))
f_ten = torch.tensor(1.0, **self.tensor_args)
q_low = torch.abs(q_low)
q_high = torch.abs(q_high)
p_plus = torch.where(q_low > q_high, dh_term, f_ten).unsqueeze(-2)
p_minus = torch.where(q_low > q_high, f_ten, dh_term).unsqueeze(-2)
j_sign = torch.sign(jac_batch)
l_delta = torch.sign(self.delta_vector) * j_sign
L = torch.where(l_delta < 0.0, p_minus, p_plus)
with torch.cuda.amp.autocast(enabled=False):
w_J = L * jac_batch
J_J_t = torch.matmul(w_J, w_J.transpose(-2, -1))
score = torch.sqrt(torch.det(J_J_t))
# get actual score:
min_score = torch.min(score, dim=0)[0]
max_score = torch.max(score, dim=0)[0]
score = min_score / max_score
score[score != score] = 0.0
return score
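(Not part of this commit) A standalone numeric sketch of the unweighted Yoshikawa manipulability measure computed by manipulability() above, w(q) = sqrt(det(J J^T)); the joint-limited variant additionally scales the Jacobian by the p_plus/p_minus penalization terms. Shapes below are placeholders.
# Plain-torch check of the manipulability score.
import torch

jac = torch.randn(4, 30, 6, 7)  # [batch, horizon, 6, dof]
score = torch.sqrt(torch.det(jac @ jac.transpose(-2, -1)))
score = torch.nan_to_num(score, nan=0.0)  # mirrors the NaN guard in the cost above
print(score.shape)  # [batch, horizon]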

View File

@@ -0,0 +1,765 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
# Third Party
import torch
from torch.autograd import Function
# CuRobo
from curobo.curobolib.geom import get_pose_distance, get_pose_distance_backward
from curobo.rollout.rollout_base import Goal
from curobo.types.math import OrientationError, Pose
# Local Folder
from .cost_base import CostBase, CostConfig
class PoseErrorType(Enum):
SINGLE_GOAL = 0 #: Distance will be computed to a single goal pose
BATCH_GOAL = 1 #: Distance will be computed pairwise between query batch and goal batch
GOALSET = 2 #: Shortest Distance will be computed to a goal set
BATCH_GOALSET = 3 #: Shortest Distance to a batch goal set
@dataclass
class PoseCostConfig(CostConfig):
cost_type: PoseErrorType = PoseErrorType.BATCH_GOAL
use_metric: bool = False
run_vec_weight: Optional[List[float]] = None
def __post_init__(self):
if self.run_vec_weight is not None:
self.run_vec_weight = self.tensor_args.to_device(self.run_vec_weight)
else:
self.run_vec_weight = torch.ones(
6, device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
if self.vec_weight is None:
self.vec_weight = torch.ones(
6, device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
if self.vec_convergence is None:
self.vec_convergence = torch.zeros(
2, device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
return super().__post_init__()
@torch.jit.script
def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
grad_vec = grad_g_dist + (grad_out_distance * weight)
grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec
return grad
# full method:
@torch.jit.script
def backward_full_PoseError_jit(
grad_out_distance, grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
):
p_grad = (grad_g_dist + (grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
q_grad = (grad_r_err + (grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
# p_grad = ((grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
# q_grad = ((grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
return p_grad, q_grad
class PoseErrorDistance(Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode=PoseErrorType.BATCH_GOAL.value,
num_goals=1,
use_metric=False,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)
(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
True,
use_metric,
)
ctx.save_for_backward(out_p_vec, out_r_vec, weight, out_p_grad, out_q_grad)
return out_distance, out_position_distance, out_rotation_distance, out_idx # .view(-1,1)
@staticmethod
def backward(ctx, grad_out_distance, grad_g_dist, grad_r_err, grad_out_idx):
(g_vec_p, g_vec_q, weight, out_grad_p, out_grad_q) = ctx.saved_tensors
pos_grad = None
quat_grad = None
batch_size = g_vec_p.shape[0] * g_vec_p.shape[1]
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
pos_grad, quat_grad = get_pose_distance_backward(
out_grad_p,
out_grad_q,
grad_out_distance.contiguous(),
grad_g_dist.contiguous(),
grad_r_err.contiguous(),
weight,
g_vec_p,
g_vec_q,
batch_size,
use_distance=True,
)
# pos_grad, quat_grad = backward_full_PoseError_jit(
# grad_out_distance,
# grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
# )
elif ctx.needs_input_grad[0]:
            pos_grad = backward_PoseError_jit(grad_g_dist, grad_out_distance, weight[1], g_vec_p)
# grad_vec = grad_g_dist + (grad_out_distance * weight[1])
# pos_grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec[..., 4:]
elif ctx.needs_input_grad[2]:
            quat_grad = backward_PoseError_jit(grad_r_err, grad_out_distance, weight[0], g_vec_q)
# grad_vec = grad_r_err + (grad_out_distance * weight[0])
# quat_grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec[..., :4]
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class PoseLoss(Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode=PoseErrorType.BATCH_GOAL.value,
num_goals=1,
use_metric=False,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)
(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
False,
use_metric,
)
ctx.save_for_backward(out_p_vec, out_r_vec)
return out_distance
@staticmethod
def backward(ctx, grad_out_distance): # , grad_g_dist, grad_r_err, grad_out_idx):
pos_grad = None
quat_grad = None
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p * grad_out_distance.unsqueeze(1)
quat_grad = g_vec_q * grad_out_distance.unsqueeze(1)
pos_grad = pos_grad.unsqueeze(-2)
quat_grad = quat_grad.unsqueeze(-2)
elif ctx.needs_input_grad[0]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p * grad_out_distance.unsqueeze(1)
pos_grad = pos_grad.unsqueeze(-2)
elif ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
quat_grad = g_vec_q * grad_out_distance.unsqueeze(1)
quat_grad = quat_grad.unsqueeze(-2)
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class PoseError(Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode=PoseErrorType.BATCH_GOAL.value,
num_goals=1,
use_metric=False,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)
(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
False,
use_metric,
)
ctx.save_for_backward(out_p_vec, out_r_vec)
return out_distance
@staticmethod
def backward(ctx, grad_out_distance): # , grad_g_dist, grad_r_err, grad_out_idx):
pos_grad = None
quat_grad = None
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p
quat_grad = g_vec_q
elif ctx.needs_input_grad[0]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p
elif ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
quat_grad = g_vec_q
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
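# Note on the three autograd Functions above: PoseErrorDistance also returns the separate
# position/rotation errors and the selected goal index, and its backward pass chains the
# incoming gradients through get_pose_distance_backward. PoseLoss scales the saved error
# vectors by the incoming gradient, while PoseError returns the saved error vectors
# directly; PoseCost.forward below picks PoseLoss when return_loss is set and PoseError
# otherwise.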
class PoseCost(CostBase, PoseCostConfig):
def __init__(self, config: PoseCostConfig):
PoseCostConfig.__init__(self, **vars(config))
CostBase.__init__(self)
self.rot_weight = self.vec_weight[0:3]
self.pos_weight = self.vec_weight[3:6]
self._vec_convergence = self.tensor_args.to_device(self.vec_convergence)
self._batch_size = 0
self._horizon = 0
def update_batch_size(self, batch_size, horizon):
if batch_size != self._batch_size or horizon != self._horizon:
# batch_size = b*h
self.out_distance = torch.zeros(
(batch_size, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self.out_position_distance = torch.zeros(
(batch_size, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self.out_rotation_distance = torch.zeros(
(batch_size, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self.out_idx = torch.zeros(
(batch_size, horizon), device=self.tensor_args.device, dtype=torch.int32
)
self.out_p_vec = torch.zeros(
(batch_size, horizon, 3),
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
)
self.out_q_vec = torch.zeros(
(batch_size, horizon, 4),
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
)
self.out_p_grad = torch.zeros(
(batch_size, horizon, 3),
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
)
self.out_q_grad = torch.zeros(
(batch_size, horizon, 4),
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
)
if self._run_weight_vec is None or self._run_weight_vec.shape[1] != horizon:
self._run_weight_vec = torch.ones(
(1, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
if self.terminal and self.run_weight is not None:
self._run_weight_vec[:, :-1] *= self.run_weight
self._batch_size = batch_size
self._horizon = horizon
def _forward_goal_distribution(self, ee_pos_batch, ee_rot_batch, ee_goal_pos, ee_goal_rot):
ee_goal_pos = ee_goal_pos.unsqueeze(1)
ee_goal_pos = ee_goal_pos.unsqueeze(1)
ee_goal_rot = ee_goal_rot.unsqueeze(1)
ee_goal_rot = ee_goal_rot.unsqueeze(1)
        error, rot_error, pos_error = self._forward_single_goal(
ee_pos_batch, ee_rot_batch, ee_goal_pos, ee_goal_rot
)
min_idx = torch.argmin(error[:, :, -1], dim=0)
min_idx = min_idx.unsqueeze(1).expand(error.shape[1], error.shape[2])
if len(min_idx.shape) == 2:
min_idx = min_idx[0, 0]
error = error[min_idx]
rot_error = rot_error[min_idx]
pos_error = pos_error[min_idx]
return error, rot_error, pos_error, min_idx
def _forward_single_goal(self, ee_pos_batch, ee_rot_batch, ee_goal_pos, ee_goal_rot):
# b, h, _ = ee_pos_batch.shape
d_g_ee = ee_pos_batch - ee_goal_pos
position_err = torch.norm(self.pos_weight * d_g_ee, dim=-1)
goal_dist = position_err # .clone()
rot_err = OrientationError.apply(ee_goal_rot, ee_rot_batch, ee_rot_batch.clone()).squeeze(
-1
)
rot_err_c = rot_err.clone()
goal_dist_c = goal_dist.clone()
# clamp:
if self.vec_convergence[1] > 0.0:
position_err = torch.where(
position_err > self.vec_convergence[1], position_err, position_err * 0.0
)
if self.vec_convergence[0] > 0.0:
rot_err = torch.where(rot_err > self.vec_convergence[0], rot_err, rot_err * 0.0)
# rot_err = torch.norm(goal_orient_vec, dim = -1)
cost = self.weight[0] * rot_err + self.weight[1] * position_err
        # output dimension should be batch x horizon (traj_length)
return cost, rot_err_c, goal_dist_c
def _forward_pytorch(self, ee_pos_batch, ee_rot_batch, ee_goal_pos, ee_goal_rot):
if self.cost_type == PoseErrorType.SINGLE_GOAL:
            cost, r_err, g_dist = self._forward_single_goal(
ee_pos_batch, ee_rot_batch, ee_goal_pos, ee_goal_rot
)
elif self.cost_type == PoseErrorType.BATCH_GOAL:
            cost, r_err, g_dist = self._forward_single_goal(
ee_pos_batch, ee_rot_batch, ee_goal_pos, ee_goal_rot
)
else:
            cost, r_err, g_dist = self._forward_goal_distribution(
ee_pos_batch, ee_rot_batch, ee_goal_pos, ee_goal_rot
)
if self.terminal and self.run_weight is not None:
cost[:, :-1] *= self.run_weight
return cost, r_err, g_dist
def _update_cost_type(self, ee_goal_pos, ee_pos_batch, num_goals):
d_g = len(ee_goal_pos.shape)
b_sze = ee_goal_pos.shape[0]
if d_g == 2 and b_sze == 1: # 1, 3
self.cost_type = PoseErrorType.SINGLE_GOAL
elif d_g == 2 and b_sze == ee_pos_batch.shape[0]: # b, 3
self.cost_type = PoseErrorType.BATCH_GOAL
elif d_g == 3:
self.cost_type = PoseErrorType.GOALSET
        elif d_g == 4 and b_sze == ee_pos_batch.shape[0]:
self.cost_type = PoseErrorType.BATCH_GOALSET
def forward_out_distance(
self, ee_pos_batch, ee_rot_batch, goal: Goal, link_name: Optional[str] = None
):
if link_name is None:
goal_pose = goal.goal_pose
else:
goal_pose = goal.links_goal_pose[link_name]
ee_goal_pos = goal_pose.position
ee_goal_rot = goal_pose.quaternion
num_goals = goal_pose.n_goalset
self._update_cost_type(ee_goal_pos, ee_pos_batch, num_goals)
b, h, _ = ee_pos_batch.shape
self.update_batch_size(b, h)
distance, g_dist, r_err, idx = PoseErrorDistance.apply(
ee_pos_batch, # .view(-1, 3).contiguous(),
ee_goal_pos,
ee_rot_batch, # .view(-1, 4).contiguous(),
ee_goal_rot,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
# print(goal.batch_pose_idx.shape)
        cost = distance
        return cost, r_err, g_dist
def forward(self, ee_pos_batch, ee_rot_batch, goal: Goal, link_name: Optional[str] = None):
if link_name is None:
goal_pose = goal.goal_pose
else:
goal_pose = goal.links_goal_pose[link_name]
ee_goal_pos = goal_pose.position
ee_goal_rot = goal_pose.quaternion
num_goals = goal_pose.n_goalset
self._update_cost_type(ee_goal_pos, ee_pos_batch, num_goals)
b, h, _ = ee_pos_batch.shape
self.update_batch_size(b, h)
# return self.out_distance
# print(b,h, ee_goal_pos.shape)
if self.return_loss:
distance = PoseLoss.apply(
ee_pos_batch,
ee_goal_pos,
ee_rot_batch, # .view(-1, 4).contiguous(),
ee_goal_rot,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
else:
distance = PoseError.apply(
ee_pos_batch,
ee_goal_pos,
ee_rot_batch, # .view(-1, 4).contiguous(),
ee_goal_rot,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
cost = distance
# print(cost.shape)
return cost
def forward_pose(
self,
goal_pose: Pose,
query_pose: Pose,
batch_pose_idx: torch.Tensor,
mode: PoseErrorType = PoseErrorType.BATCH_GOAL,
):
ee_goal_pos = goal_pose.position
ee_goal_quat = goal_pose.quaternion
self.cost_type = mode
self.update_batch_size(query_pose.position.shape[0], query_pose.position.shape[1])
b = query_pose.position.shape[0]
h = query_pose.position.shape[1]
num_goals = 1
if self.return_loss:
distance = PoseLoss.apply(
query_pose.position.unsqueeze(1),
ee_goal_pos,
query_pose.quaternion.unsqueeze(1),
ee_goal_quat,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
else:
distance = PoseError.apply(
query_pose.position.unsqueeze(1),
ee_goal_pos,
query_pose.quaternion.unsqueeze(1),
ee_goal_quat,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
return distance

View File

@@ -0,0 +1,214 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from typing import Optional, Union
# Third Party
import torch
# CuRobo
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollision
from curobo.rollout.cost.cost_base import CostBase, CostConfig
from curobo.rollout.dynamics_model.integration_utils import interpolate_kernel, sum_matrix
@dataclass
class PrimitiveCollisionCostConfig(CostConfig):
"""Create Collision Cost Configuration."""
#: WorldCollision instance to use for distance queries.
world_coll_checker: Optional[WorldCollision] = None
#: Sweep for collisions between timesteps in a trajectory.
use_sweep: bool = False
use_sweep_kernel: bool = False
sweep_steps: int = 4
#: Speed metric scales the collision distance by sphere velocity (similar to CHOMP Planner
#: ICRA'09). This prevents the optimizer from speeding through obstacles to minimize cost and
#: instead encourages the robot to move around the obstacle.
use_speed_metric: bool = False
    #: dt to use for computing velocity and acceleration through central difference for the
    #: speed metric. A value less than 1 is preferred, as it leads to a different scaling
    #: between acceleration and velocity.
speed_dt: Union[torch.Tensor, float] = 0.01
    #: The distance outside collision at which to activate the cost. A non-zero value allows
    #: the robot to slow down when it is within this distance of an obstacle, which helps the
    #: post-optimization interpolation stay collision-free.
activation_distance: Union[torch.Tensor, float] = 0.0
#: Setting this flag to true will sum the distance across spheres of the robot.
sum_distance: bool = True
def __post_init__(self):
if isinstance(self.speed_dt, float):
self.speed_dt = self.tensor_args.to_device([self.speed_dt])
if isinstance(self.activation_distance, float):
self.activation_distance = self.tensor_args.to_device([self.activation_distance])
return super().__post_init__()
class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
def __init__(self, config: PrimitiveCollisionCostConfig):
"""Creates a primitive collision cost instance.
See note on :ref:`collision_checking_note` for details on the cost formulation.
Args:
            config: Cost configuration for the primitive collision cost.
"""
PrimitiveCollisionCostConfig.__init__(self, **vars(config))
CostBase.__init__(self)
self._og_speed_dt = self.speed_dt.clone()
self.batch_size = -1
self._horizon = -1
self._n_spheres = -1
self.t_mat = None
if self.classify:
self.coll_check_fn = self.world_coll_checker.get_sphere_collision
self.sweep_check_fn = self.world_coll_checker.get_swept_sphere_collision
else:
self.coll_check_fn = self.world_coll_checker.get_sphere_distance
self.sweep_check_fn = self.world_coll_checker.get_swept_sphere_distance
self.sampled_spheres = None
self.sum_mat = None #
if self.use_sweep:
# if self.use_sweep_kernel and (
# type(self.world_coll_checker) in [WorldMeshCollision, WorldPrimitiveCollision]
# ):
# TODO: Implement sweep for nvblox collision checker.
self.forward = self.sweep_kernel_fn
# else:
# self.forward = self.discrete_fn
else:
self.forward = self.discrete_fn
self.int_mat = None
self._fd_matrix = None
self._collision_query_buffer = CollisionQueryBuffer()
def sweep_kernel_fn(self, robot_spheres_in, env_query_idx: Optional[torch.Tensor] = None):
self._collision_query_buffer.update_buffer_shape(
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
)
dist = self.sweep_check_fn(
robot_spheres_in,
self._collision_query_buffer,
self.weight,
sweep_steps=self.sweep_steps,
activation_distance=self.activation_distance,
speed_dt=self.speed_dt,
enable_speed_metric=self.use_speed_metric,
env_query_idx=env_query_idx,
return_loss=self.return_loss,
)
if self.classify:
cost = weight_collision(dist, self.weight, self.sum_distance)
else:
cost = weight_distance(dist, self.weight, self.sum_distance)
return cost
def sweep_fn(self, robot_spheres_in, env_query_idx: Optional[torch.Tensor] = None):
batch_size, horizon, n_spheres, _ = robot_spheres_in.shape
# add intermediate spheres to account for discretization:
new_horizon = (horizon - 1) * self.sweep_steps
if self.int_mat is None:
self.int_mat = interpolate_kernel(horizon, self.sweep_steps, self.tensor_args)
self.int_mat_t = self.int_mat.transpose(0, 1)
self.int_sum_mat = sum_matrix(horizon, self.sweep_steps, self.tensor_args)
sampled_spheres = (
(robot_spheres_in.transpose(1, 2).transpose(2, 3) @ self.int_mat_t)
.transpose(2, 3)
.transpose(1, 2)
.contiguous()
)
# robot_spheres = sampled_spheres.view(batch_size * new_horizon * n_spheres, 4)
# self.update_batch_size(batch_size * new_horizon * n_spheres)
self._collision_query_buffer.update_buffer_shape(
sampled_spheres.shape, self.tensor_args, self.world_coll_checker.collision_types
)
dist = self.coll_check_fn(
sampled_spheres.contiguous(),
self._collision_query_buffer,
self.weight,
activation_distance=self.activation_distance,
env_query_idx=env_query_idx,
return_loss=self.return_loss,
)
dist = dist.view(batch_size, new_horizon, n_spheres)
if self.classify:
cost = weight_sweep_collision(self.int_sum_mat, dist, self.weight, self.sum_distance)
else:
cost = weight_sweep_distance(self.int_sum_mat, dist, self.weight, self.sum_distance)
return cost
def discrete_fn(self, robot_spheres_in, env_query_idx: Optional[torch.Tensor] = None):
self._collision_query_buffer.update_buffer_shape(
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
)
dist = self.coll_check_fn(
robot_spheres_in,
self._collision_query_buffer,
self.weight,
env_query_idx=env_query_idx,
activation_distance=self.activation_distance,
return_loss=self.return_loss,
)
if self.classify:
cost = weight_collision(dist, self.weight, self.sum_distance)
else:
cost = weight_distance(dist, self.weight, self.sum_distance)
return cost
def update_dt(self, dt: Union[float, torch.Tensor]):
self.speed_dt[:] = dt # / self._og_speed_dt
return super().update_dt(dt)
def get_gradient_buffer(self):
return self._collision_query_buffer.get_gradient_buffer()
@torch.jit.script
def weight_sweep_distance(int_mat, dist, weight, sum_cost: bool):
dist = torch.sum(dist, dim=-1)
dist = dist @ int_mat
return dist
@torch.jit.script
def weight_sweep_collision(int_mat, dist, weight, sum_cost: bool):
dist = torch.sum(dist, dim=-1)
dist = torch.where(dist > 0, dist + 1.0, dist)
dist = dist @ int_mat
return dist
@torch.jit.script
def weight_distance(dist, weight, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
return dist
@torch.jit.script
def weight_collision(dist, weight, sum_cost: bool):
if sum_cost:
dist = torch.sum(dist, dim=-1)
dist = torch.where(dist > 0, dist + 1.0, dist)
return dist

View File

@@ -0,0 +1,69 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
"""
Distance cost projected into the null-space of the Jacobian
"""
# Standard Library
from dataclasses import dataclass
from enum import Enum
# Third Party
import torch
# Local Folder
from .dist_cost import DistCost, DistCostConfig
class ProjType(Enum):
IDENTITY = 0
PSEUDO_INVERSE = 1
@dataclass
class ProjectedDistCostConfig(DistCostConfig):
eps: float = 1e-4
proj_type: ProjType = ProjType.IDENTITY
def __post_init__(self):
return super().__post_init__()
class ProjectedDistCost(DistCost, ProjectedDistCostConfig):
def __init__(self, config: ProjectedDistCostConfig):
ProjectedDistCostConfig.__init__(self, **vars(config))
DistCost.__init__(self)
self.I = torch.eye(self.dof, device=self.tensor_args.device, dtype=self.tensor_args.dtype)
self.task_I = torch.eye(6, device=self.tensor_args.device, dtype=self.tensor_args.dtype)
def forward(self, disp_vec, jac_batch=None):
disp_vec = self.vec_weight * disp_vec
if self.proj_type == ProjType.PSEUDO_INVERSE:
disp_vec_projected = self.get_pinv_null_disp(disp_vec, jac_batch)
elif self.proj_type == ProjType.IDENTITY:
disp_vec_projected = disp_vec
return super().forward(disp_vec_projected)
def get_pinv_null_disp(self, disp_vec, jac_batch):
jac_batch_t = jac_batch.transpose(-2, -1)
J_J_t = torch.matmul(jac_batch, jac_batch_t)
J_pinv = jac_batch_t @ torch.inverse(J_J_t + self.eps * self.task_I.expand_as(J_J_t))
J_pinv_J = torch.matmul(J_pinv, jac_batch)
null_proj = self.I.expand_as(J_pinv_J) - J_pinv_J
null_disp = torch.matmul(null_proj, disp_vec.unsqueeze(-1)).squeeze(-1)
return null_disp
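# Illustrative sketch (hypothetical helper, not part of the original API): the same
# pseudo-inverse null-space projection as get_pinv_null_disp above, on a random Jacobian.
# The projected displacement produces (near) zero task-space motion, i.e. J @ null_disp ~ 0.
def _example_null_space_projection(eps: float = 1e-4) -> torch.Tensor:
    dof = 7
    jac = torch.rand(1, 1, 6, dof)  # assumed batch x horizon x 6 x dof Jacobian layout
    disp = torch.rand(1, 1, dof)  # joint-space displacement to project
    jac_t = jac.transpose(-2, -1)
    J_J_t = jac @ jac_t
    J_pinv = jac_t @ torch.inverse(J_J_t + eps * torch.eye(6).expand_as(J_J_t))
    null_proj = torch.eye(dof) - J_pinv @ jac
    null_disp = (null_proj @ disp.unsqueeze(-1)).squeeze(-1)
    # residual task-space motion of the projected displacement (should be ~0):
    return torch.norm(jac @ null_disp.unsqueeze(-1))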

View File

@@ -0,0 +1,79 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from typing import Optional
# Third Party
import torch
# CuRobo
from curobo.cuda_robot_model.types import SelfCollisionKinematicsConfig
from curobo.curobolib.geom import SelfCollisionDistance
# Local Folder
from .cost_base import CostBase, CostConfig
@dataclass
class SelfCollisionCostConfig(CostConfig):
self_collision_kin_config: Optional[SelfCollisionKinematicsConfig] = None
def __post_init__(self):
return super().__post_init__()
class SelfCollisionCost(CostBase, SelfCollisionCostConfig):
def __init__(self, config: SelfCollisionCostConfig):
SelfCollisionCostConfig.__init__(self, **vars(config))
CostBase.__init__(self)
self._batch_size = None
def update_batch_size(self, robot_spheres):
# Assuming n stays constant
# TODO: use collision buffer here?
if self._batch_size is None or self._batch_size != robot_spheres.shape:
b, h, n, k = robot_spheres.shape
self._out_distance = torch.zeros(
(b, h), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self._out_vec = torch.zeros(
(b, h, n, k), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self._batch_size = robot_spheres.shape
self._sparse_sphere_idx = torch.zeros(
(b, h, n), device=self.tensor_args.device, dtype=torch.uint8
)
def forward(self, robot_spheres):
self.update_batch_size(robot_spheres)
dist = SelfCollisionDistance.apply(
self._out_distance,
self._out_vec,
self._sparse_sphere_idx,
robot_spheres,
self.self_collision_kin_config.offset,
self.weight,
self.self_collision_kin_config.collision_matrix,
self.self_collision_kin_config.thread_location,
self.self_collision_kin_config.thread_max,
self.self_collision_kin_config.checks_per_thread,
# False,
self.self_collision_kin_config.experimental_kernel,
self.return_loss,
)
if self.classify:
dist = torch.where(dist > 0, dist + 1.0, dist)
return dist

View File

@@ -0,0 +1,76 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from typing import Optional
# Third Party
import torch
# CuRobo
from curobo.rollout.dynamics_model.kinematic_model import TimeTrajConfig
# Local Folder
from .cost_base import CostBase, CostConfig
@dataclass
class StopCostConfig(CostConfig):
max_limit: Optional[float] = None
max_nlimit: Optional[float] = None
dt_traj_params: Optional[TimeTrajConfig] = None
horizon: int = 1
def __post_init__(self):
return super().__post_init__()
class StopCost(CostBase, StopCostConfig):
def __init__(self, config: StopCostConfig):
StopCostConfig.__init__(self, **vars(config))
CostBase.__init__(self)
sum_matrix = torch.tril(
torch.ones(
(self.horizon, self.horizon),
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
)
).T
traj_dt = self.tensor_args.to_device(self.dt_traj_params.get_dt_array(self.horizon))
        if self.max_nlimit is not None:
            # every timestep max acceleration: reuse the cumulative-sum matrix built above
            delta_vel = traj_dt * self.max_nlimit
            self.max_vel = (sum_matrix @ delta_vel).unsqueeze(-1)
        elif self.max_limit is not None:
            # constant velocity limit at every timestep
            delta_vel = torch.ones_like(traj_dt) * self.max_limit
            self.max_vel = (sum_matrix @ delta_vel).unsqueeze(-1)
def forward(self, vels):
vel_abs = torch.abs(vels)
vel_abs = torch.nn.functional.relu(vel_abs - self.max_vel)
cost = self.weight * (torch.sum(vel_abs**2, dim=-1))
return cost
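# Illustrative sketch (hypothetical helper, not part of the original API): the transposed
# lower-triangular matrix built in StopCost.__init__ accumulates the remaining per-step
# velocity budget, giving a velocity bound that shrinks toward the end of the horizon and
# forces the rollout to come to a stop.
def _example_stop_cost_velocity_bound(horizon: int = 5, max_nlimit: float = 2.0) -> torch.Tensor:
    traj_dt = torch.full((horizon,), 0.1)  # assumed constant dt for this sketch
    sum_matrix = torch.tril(torch.ones((horizon, horizon))).T
    delta_vel = traj_dt * max_nlimit
    max_vel = (sum_matrix @ delta_vel).unsqueeze(-1)
    # for this dt and limit, max_vel is [[1.0], [0.8], [0.6], [0.4], [0.2]]
    return max_vel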

View File

@@ -0,0 +1,45 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Third Party
import torch
# Local Folder
from .cost_base import CostBase, CostConfig
@torch.jit.script
def st_cost(ee_pos_batch, vec_weight, weight):
ee_plus_one = torch.roll(ee_pos_batch, 1, dims=1)
xdot_current = ee_pos_batch - ee_plus_one + 1e-8
    err_vec = vec_weight * xdot_current / 0.02  # 0.02 acts as a fixed dt for this finite difference
error = torch.sum(torch.square(err_vec), dim=-1)
# compute distance vector
cost = weight * error
return cost
class StraightLineCost(CostBase):
def __init__(self, config: CostConfig):
CostBase.__init__(self, config)
self.vel_idxs = torch.arange(
self.dof, 2 * self.dof, dtype=torch.long, device=self.tensor_args.device
)
self.I = torch.eye(self.dof, device=self.tensor_args.device, dtype=self.tensor_args.dtype)
def forward(self, ee_pos_batch):
cost = st_cost(ee_pos_batch, self.vec_weight, self.weight)
return cost

View File

@@ -0,0 +1,116 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Third Party
import torch
# Local Folder
from .cost_base import CostBase
@torch.jit.script
def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
# return weight * torch.square(torch.linalg.norm(cost, dim=-1, ord=1))
# return weight * torch.sum(torch.square(cost), dim=-1)
# return torch.sum(torch.abs(cost) * weight, dim=-1)
return torch.sum(torch.square(cost) * weight, dim=-1)
@torch.jit.script
def run_squared_sum(
cost: torch.Tensor, weight: torch.Tensor, run_weight: torch.Tensor
) -> torch.Tensor:
# return torch.sum(torch.abs(cost)* weight * run_weight.unsqueeze(-1), dim=-1)
## below is smaller compute but more kernels
return torch.sum(torch.square(cost) * weight * run_weight.unsqueeze(-1), dim=-1)
# return torch.sum(torch.square(cost), dim=-1) * weight * run_weight
@torch.jit.script
def backward_squared_sum(cost_vec, w):
return 2.0 * w * cost_vec # * g_out.unsqueeze(-1)
# return w * g_out.unsqueeze(-1)
@torch.jit.script
def backward_run_squared_sum(cost_vec, w, r_w):
return 2.0 * w * r_w.unsqueeze(-1) * cost_vec # * g_out.unsqueeze(-1)
# return w * r_w.unsqueeze(-1) * cost_vec * g_out.unsqueeze(-1)
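# Illustrative sketch (hypothetical helper, not part of the original API): for a summed
# scalar loss, the analytic gradient from backward_squared_sum matches autograd on
# squared_sum, i.e. d/dx sum_j(w_j * x_j^2) = 2 * w * x.
def _example_squared_sum_gradient() -> bool:
    x = torch.rand(2, 3, 4, requires_grad=True)
    w = torch.rand(4)
    squared_sum(x, w).sum().backward()
    analytic = backward_squared_sum(x.detach(), w)
    return bool(torch.allclose(x.grad, analytic))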
class SquaredSum(torch.autograd.Function):
@staticmethod
def forward(
ctx,
cost_vec,
weight,
):
cost = squared_sum(cost_vec, weight)
ctx.save_for_backward(cost_vec, weight)
return cost
@staticmethod
def backward(ctx, grad_out_cost):
(cost_vec, w) = ctx.saved_tensors
c_grad = None
if ctx.needs_input_grad[0]:
c_grad = backward_squared_sum(cost_vec, w)
return c_grad, None
class RunSquaredSum(torch.autograd.Function):
@staticmethod
def forward(
ctx,
cost_vec,
weight,
run_weight,
):
cost = run_squared_sum(cost_vec, weight, run_weight)
ctx.save_for_backward(cost_vec, weight, run_weight)
return cost
@staticmethod
def backward(ctx, grad_out_cost):
(cost_vec, w, r_w) = ctx.saved_tensors
c_grad = None
if ctx.needs_input_grad[0]:
c_grad = backward_run_squared_sum(cost_vec, w, r_w)
return c_grad, None, None
class ZeroCost(CostBase):
"""Zero Cost"""
def forward(self, x, goal_dist):
err = x
if self.max_value is not None:
err = torch.nn.functional.relu(torch.abs(err) - self.max_value)
if self.hinge_value is not None:
err = torch.where(goal_dist <= self.hinge_value, err, self._z_scalar) # soft hinge
if self.threshold_value is not None:
err = torch.where(err <= self.distance_threshold, self._z_scalar, err)
if not self.terminal: # or self.run_weight is not None:
cost = SquaredSum.apply(err, self.weight)
else:
if self._run_weight_vec is None or self._run_weight_vec.shape[1] != err.shape[1]:
self._run_weight_vec = torch.ones(
(1, err.shape[1]), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
self._run_weight_vec[:, 1:-1] *= self.run_weight
cost = RunSquaredSum.apply(
err, self.weight, self._run_weight_vec
) # cost * self._run_weight_vec
return cost

View File

@@ -0,0 +1,10 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#

View File

@@ -0,0 +1,905 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from typing import List
# Third Party
import torch
from packaging import version
# CuRobo
from curobo.curobolib.tensor_step import (
tensor_step_acc_fwd,
tensor_step_acc_idx_fwd,
tensor_step_pos_clique_bwd,
tensor_step_pos_clique_fwd,
tensor_step_pos_clique_idx_fwd,
)
from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
def build_clique_matrix(horizon, dt, device="cpu", dtype=torch.float32):
diag_dt = torch.diag(1 / dt)
one_t = torch.ones(horizon - 1, device=device, dtype=dtype)
fd_mat_pos = torch.diag_embed(one_t, offset=-1)
fd_mat_vel = -1.0 * torch.diag_embed(one_t, offset=-1)
one_t = torch.ones(horizon - 1, device=device, dtype=dtype)
fd_mat_vel += torch.eye(horizon, device=device, dtype=dtype)
fd_mat_vel[0, 0] = 0.0
fd_mat_vel = diag_dt @ fd_mat_vel
fd_mat_acc = diag_dt @ fd_mat_vel.clone()
fd_mat = torch.cat((fd_mat_pos, fd_mat_vel, fd_mat_acc), dim=0)
return fd_mat
def build_fd_matrix(
horizon,
device="cpu",
dtype=torch.float32,
order=1,
PREV_STATE=False,
FULL_RANK=False,
SHIFT=False,
):
if PREV_STATE:
# build order 1 fd matrix of horizon+order size
fd1_mat = build_fd_matrix(horizon + order, device, dtype, order=1)
# multiply order times to get fd_order matrix [h+order, h+order]
fd_mat = fd1_mat
fd_single = fd_mat.clone()
for _ in range(order - 1):
fd_mat = fd_single @ fd_mat
# return [horizon,h+order]
fd_mat = -1.0 * fd_mat[:horizon, :]
# fd_mat = torch.zeros((horizon, horizon + order),device=device, dtype=dtype)
# one_t = torch.ones(horizon, device=device, dtype=dtype)
# fd_mat[:horizon, :horizon] = torch.diag_embed(one_t)
# print(torch.diag_embed(one_t, offset=1).shape, fd_mat.shape)
# fd_mat += - torch.diag_embed(one_t, offset=1)[:-1,:]
elif FULL_RANK:
fd_mat = torch.eye(horizon, device=device, dtype=dtype)
one_t = torch.ones(horizon // 2, device=device, dtype=dtype)
fd_mat[: horizon // 2, : horizon // 2] = torch.diag_embed(one_t)
fd_mat[: horizon // 2 + 1, : horizon // 2 + 1] += -torch.diag_embed(one_t, offset=1)
one_t = torch.ones(horizon // 2, device=device, dtype=dtype)
fd_mat[horizon // 2 :, horizon // 2 :] += -torch.diag_embed(one_t, offset=-1)
fd_mat[horizon // 2, horizon // 2] = 0.0
fd_mat[horizon // 2, horizon // 2 - 1] = -1.0
fd_mat[horizon // 2, horizon // 2 + 1] = 1.0
else:
fd_mat = torch.zeros((horizon, horizon), device=device, dtype=dtype)
if horizon > 1:
one_t = torch.ones(horizon - 1, device=device, dtype=dtype)
if not SHIFT:
fd_mat[: horizon - 1, : horizon - 1] = -1.0 * torch.diag_embed(one_t)
fd_mat += torch.diag_embed(one_t, offset=1)
else:
fd_mat[1:, : horizon - 1] = -1.0 * torch.diag_embed(one_t)
fd_mat[1:, 1:] += torch.diag_embed(one_t)
fd_og = fd_mat.clone()
for _ in range(order - 1):
fd_mat = fd_og @ fd_mat
# if order > 1:
# #print(order, fd_mat)
# for i in range(order):
# fd_mat[i,:] /= (2**(i+2))
# #print(order, fd_mat[order])
# #print(order, fd_mat)
# fd_mat[:order]
# if order > 1:
# fd_mat[:order-1, :] = 0.0
# recreate this as a sparse tensor?
# print(fd_mat)
# sparse_indices = []
# sparse_values = []
# for i in range(horizon-1):
# sparse_indices.extend([[i,i], [i,i+1]])
# sparse_values.extend([-1.0, 1.0])
# sparse_indices.extend([[horizon-1, horizon-1]])
# sparse_values.extend([0.0])
# fd_kernel = torch.sparse_coo_tensor(torch.tensor(sparse_indices).t(),
# torch.tensor(sparse_values), device=device, dtype=dtype)
# fd_mat = fd_kernel.to_dense()
return fd_mat
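# Illustrative sketch (hypothetical helper, not part of the original API): the default
# first-order finite-difference matrix has -1 on the diagonal and +1 on the superdiagonal
# (last row zero), so multiplying a position trajectory by it gives per-step differences.
def _example_build_fd_matrix() -> torch.Tensor:
    fd = build_fd_matrix(4)
    # fd == [[-1.,  1.,  0.,  0.],
    #        [ 0., -1.,  1.,  0.],
    #        [ 0.,  0., -1.,  1.],
    #        [ 0.,  0.,  0.,  0.]]
    pos = torch.arange(4, dtype=torch.float32).unsqueeze(-1)  # a unit ramp trajectory
    return fd @ pos  # [[1.], [1.], [1.], [0.]]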
def build_int_matrix(horizon, diagonal=0, device="cpu", dtype=torch.float32, order=1, traj_dt=None):
integrate_matrix = torch.tril(
torch.ones((horizon, horizon), device=device, dtype=dtype), diagonal=diagonal
)
chain_list = [torch.eye(horizon, device=device, dtype=dtype)]
if traj_dt is None:
chain_list.extend([integrate_matrix for i in range(order)])
else:
diag_dt = torch.diag(traj_dt)
for _ in range(order):
chain_list.append(integrate_matrix)
chain_list.append(diag_dt)
if len(chain_list) == 1:
integrate_matrix = chain_list[0]
elif version.parse(torch.__version__) < version.parse("1.9.0"):
integrate_matrix = torch.chain_matmul(*chain_list)
else:
integrate_matrix = torch.linalg.multi_dot(chain_list)
return integrate_matrix
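# Illustrative sketch (hypothetical helper, not part of the original API): the first-order
# integration matrix is lower triangular, so multiplying a per-step increment vector by it
# produces a cumulative sum (discrete integration) along the horizon.
def _example_build_int_matrix() -> torch.Tensor:
    int_mat = build_int_matrix(4)
    increments = torch.ones(4)
    return int_mat @ increments  # cumulative sum: [1., 2., 3., 4.]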
def build_start_state_mask(horizon, tensor_args: TensorDeviceType):
mask = torch.zeros((horizon, 1), device=tensor_args.device, dtype=tensor_args.dtype)
# n_mask = torch.eye(horizon, device=tensor_args.device, dtype=tensor_args.dtype)
n_mask = torch.diag_embed(
torch.ones((horizon - 1), device=tensor_args.device, dtype=tensor_args.dtype), offset=-1
)
mask[0, 0] = 1.0
# n_mask[0,0] = 0.0
return mask, n_mask
# @torch.jit.script
def tensor_step_jerk(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
# This is batch,n_dof
q = state[:, :n_dofs]
qd = state[:, n_dofs : 2 * n_dofs]
qdd = state[:, 2 * n_dofs : 3 * n_dofs]
diag_dt = torch.diag(dt_h)
# qd_new = act
# integrate velocities:
qdd_new = qdd + torch.matmul(integrate_matrix, torch.matmul(diag_dt, act))
qd_new = qd + torch.matmul(integrate_matrix, torch.matmul(diag_dt, qdd_new))
q_new = q + torch.matmul(integrate_matrix, torch.matmul(diag_dt, qd_new))
state_seq[:, :, :n_dofs] = q_new
state_seq[:, :, n_dofs : n_dofs * 2] = qd_new
state_seq[:, :, n_dofs * 2 : n_dofs * 3] = qdd_new
return state_seq
# @torch.jit.script
def euler_integrate(q_0, u, diag_dt, integrate_matrix):
# q_new = q_0 + torch.matmul(integrate_matrix, torch.matmul(diag_dt, u))
q_new = q_0 + torch.matmul(integrate_matrix, u)
# q_new = torch.addmm(q_0,integrate_matrix,torch.matmul(diag_dt, u))
return q_new
# @torch.jit.script
def tensor_step_acc(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
# This is batch,n_dof
q = state[..., :n_dofs]
qd = state[..., n_dofs : 2 * n_dofs]
qdd_new = act
diag_dt = torch.diag(dt_h)
diag_dt_2 = torch.diag(dt_h**2)
qd_new = euler_integrate(qd, qdd_new, diag_dt, integrate_matrix)
q_new = euler_integrate(q, qd_new, diag_dt, integrate_matrix)
state_seq[..., n_dofs * 2 : n_dofs * 3] = qdd_new
state_seq[..., n_dofs : n_dofs * 2] = qd_new
# state_seq[:,1:, n_dofs: n_dofs * 2] = qd_new[:,:-1,:]
# state_seq[:,0:1, n_dofs: n_dofs * 2] = qd
# state_seq[:,1:, :n_dofs] = q_new[:,:-1,:] #+ 0.5 * torch.matmul(diag_dt_2,qdd_new)
state_seq[..., :n_dofs] = q_new # + 0.5 * torch.matmul(diag_dt_2,qdd_new)
# state_seq[:,0:1, :n_dofs] = q #state[...,:n_dofs]
return state_seq
@torch.jit.script
def jit_tensor_step_pos_clique_contiguous(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
state_position = (start_position.unsqueeze(1).transpose(1, 2) @ mask.transpose(0, 1)) + (
pos_act.transpose(1, 2) @ n_mask.transpose(0, 1)
)
# state_position = mask @ start_position.unsqueeze(1) + n_mask @ pos_act
# print(state_position.shape, fd_1.shape)
# # below 3 can be done in parallel:
state_vel = (state_position @ fd_1.transpose(0, 1)).transpose(1, 2).contiguous()
state_acc = (state_position @ fd_2.transpose(0, 1)).transpose(1, 2).contiguous()
state_jerk = (state_position @ fd_3.transpose(0, 1)).transpose(1, 2).contiguous()
state_position = state_position.transpose(1, 2).contiguous()
return state_position, state_vel, state_acc, state_jerk
@torch.jit.script
def jit_tensor_step_pos_clique(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
state_position = mask @ start_position.unsqueeze(1) + n_mask @ pos_act
state_vel = fd_1 @ state_position
state_acc = fd_2 @ state_position
state_jerk = fd_3 @ state_position
return state_position, state_vel, state_acc, state_jerk
@torch.jit.script
def jit_backward_pos_clique(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
p_grad = (
grad_p
+ (fd_3).transpose(-1, -2) @ grad_j
+ (fd_2).transpose(-1, -2) @ grad_a
+ (fd_1).transpose(-1, -2) @ grad_v
)
u_grad = (n_mask).transpose(-1, -2) @ p_grad
# u_grad = n_mask @ p_grad
# p_grad = fd_3 @ grad_j + fd_2 @ grad_a + fd_1 @ grad_v + grad_p
# u_grad = n_mask @ p_grad
return u_grad
@torch.jit.script
def jit_backward_pos_clique_contiguous(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
p_grad = grad_p + (
grad_j.transpose(-1, -2) @ fd_3
+ grad_a.transpose(-1, -2) @ fd_2
+ grad_v.transpose(-1, -2) @ fd_1
).transpose(-1, -2)
# u_grad = (n_mask).transpose(-1, -2) @ p_grad
u_grad = (p_grad.transpose(-1, -2) @ n_mask).transpose(-1, -2).contiguous()
return u_grad
class CliqueTensorStep(torch.autograd.Function):
@staticmethod
def forward(
ctx,
u_act,
start_position,
mask,
n_mask,
fd_1,
fd_2,
fd_3,
):
state_position, state_vel, state_acc, state_jerk = jit_tensor_step_pos_clique(
u_act, start_position, mask, n_mask, fd_1, fd_2, fd_3
)
ctx.save_for_backward(n_mask, fd_1, fd_2, fd_3)
return state_position, state_vel, state_acc, state_jerk
@staticmethod
def backward(ctx, grad_out_p, grad_out_v, grad_out_a, grad_out_j):
u_grad = None
(n_mask, fd_1, fd_2, fd_3) = ctx.saved_tensors
if ctx.needs_input_grad[0]:
u_grad = jit_backward_pos_clique(
grad_out_p, grad_out_v, grad_out_a, grad_out_j, n_mask, fd_1, fd_2, fd_3
)
return u_grad, None, None, None, None, None, None
class CliqueTensorStepKernel(torch.autograd.Function):
@staticmethod
def forward(
ctx,
u_act,
start_position,
start_velocity,
start_acceleration,
out_position,
out_velocity,
out_acceleration,
out_jerk,
traj_dt,
out_grad_position,
):
(
state_position,
state_velocity,
state_acceleration,
state_jerk,
) = tensor_step_pos_clique_fwd(
out_position,
out_velocity,
out_acceleration,
out_jerk,
u_act,
start_position,
start_velocity,
start_acceleration,
traj_dt,
out_position.shape[0],
out_position.shape[1],
out_position.shape[-1],
)
ctx.save_for_backward(traj_dt, out_grad_position)
return state_position, state_velocity, state_acceleration, state_jerk
@staticmethod
def backward(ctx, grad_out_p, grad_out_v, grad_out_a, grad_out_j):
u_grad = None
if ctx.needs_input_grad[0]:
(traj_dt, out_grad_position) = ctx.saved_tensors
u_grad = tensor_step_pos_clique_bwd(
out_grad_position,
grad_out_p,
grad_out_v,
grad_out_a,
grad_out_j,
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
)
return (
u_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class CliqueTensorStepIdxKernel(torch.autograd.Function):
@staticmethod
def forward(
ctx,
u_act,
start_position,
start_velocity,
start_acceleration,
start_idx,
out_position,
out_velocity,
out_acceleration,
out_jerk,
traj_dt,
out_grad_position,
):
(
state_position,
state_velocity,
state_acceleration,
state_jerk,
) = tensor_step_pos_clique_idx_fwd(
out_position,
out_velocity,
out_acceleration,
out_jerk,
u_act,
start_position,
start_velocity,
start_acceleration,
start_idx,
traj_dt,
out_position.shape[0],
out_position.shape[1],
out_position.shape[-1],
)
ctx.save_for_backward(traj_dt, out_grad_position)
return state_position, state_velocity, state_acceleration, state_jerk
@staticmethod
def backward(ctx, grad_out_p, grad_out_v, grad_out_a, grad_out_j):
u_grad = None
if ctx.needs_input_grad[0]:
(traj_dt, out_grad_position) = ctx.saved_tensors
u_grad = tensor_step_pos_clique_bwd(
out_grad_position,
grad_out_p,
grad_out_v,
grad_out_a,
grad_out_j,
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
)
return (
u_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
) # , None, None, None, None,None
class CliqueTensorStepCentralDifferenceKernel(torch.autograd.Function):
@staticmethod
def forward(
ctx,
u_act,
start_position,
start_velocity,
start_acceleration,
out_position,
out_velocity,
out_acceleration,
out_jerk,
traj_dt,
out_grad_position,
):
(
state_position,
state_velocity,
state_acceleration,
state_jerk,
) = tensor_step_pos_clique_fwd(
out_position,
out_velocity,
out_acceleration,
out_jerk,
u_act,
start_position,
start_velocity,
start_acceleration,
traj_dt,
out_position.shape[0],
out_position.shape[1],
out_position.shape[-1],
0,
)
ctx.save_for_backward(traj_dt, out_grad_position)
return state_position, state_velocity, state_acceleration, state_jerk
@staticmethod
def backward(ctx, grad_out_p, grad_out_v, grad_out_a, grad_out_j):
u_grad = None
if ctx.needs_input_grad[0]:
(traj_dt, out_grad_position) = ctx.saved_tensors
u_grad = tensor_step_pos_clique_bwd(
out_grad_position,
grad_out_p,
grad_out_v,
grad_out_a.contiguous(),
grad_out_j.contiguous(),
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
0,
)
return (
u_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class CliqueTensorStepIdxCentralDifferenceKernel(torch.autograd.Function):
@staticmethod
def forward(
ctx,
u_act,
start_position,
start_velocity,
start_acceleration,
start_idx,
out_position,
out_velocity,
out_acceleration,
out_jerk,
traj_dt,
out_grad_position,
):
(
state_position,
state_velocity,
state_acceleration,
state_jerk,
) = tensor_step_pos_clique_idx_fwd(
out_position,
out_velocity,
out_acceleration,
out_jerk,
u_act,
start_position,
start_velocity,
start_acceleration,
start_idx,
traj_dt,
out_position.shape[0],
out_position.shape[1],
out_position.shape[-1],
0,
)
ctx.save_for_backward(traj_dt, out_grad_position)
return state_position, state_velocity, state_acceleration, state_jerk
@staticmethod
def backward(ctx, grad_out_p, grad_out_v, grad_out_a, grad_out_j):
u_grad = None
if ctx.needs_input_grad[0]:
(traj_dt, out_grad_position) = ctx.saved_tensors
u_grad = tensor_step_pos_clique_bwd(
out_grad_position,
grad_out_p,
grad_out_v,
grad_out_a.contiguous(),
grad_out_j.contiguous(),
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
0,
)
return (
u_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
) # , None, None, None, None,None
class CliqueTensorStepCoalesceKernel(torch.autograd.Function):
@staticmethod
def forward(
ctx,
u_act,
start_position,
start_velocity,
start_acceleration,
out_position,
out_velocity,
out_acceleration,
out_jerk,
traj_dt,
out_grad_position,
):
state_position, state_velocity, state_acceleration, state_jerk = tensor_step_pos_clique_fwd(
out_position.transpose(-1, -2).contiguous(),
out_velocity.transpose(-1, -2).contiguous(),
out_acceleration.transpose(-1, -2).contiguous(),
out_jerk.transpose(-1, -2).contiguous(),
u_act.transpose(-1, -2).contiguous(),
start_position,
start_velocity,
start_acceleration,
traj_dt,
out_position.shape[0],
out_position.shape[1],
out_position.shape[-1],
)
ctx.save_for_backward(traj_dt, out_grad_position)
return (
state_position.transpose(-1, -2).contiguous(),
state_velocity.transpose(-1, -2).contiguous(),
state_acceleration.transpose(-1, -2).contiguous(),
state_jerk.transpose(-1, -2).contiguous(),
)
@staticmethod
def backward(ctx, grad_out_p, grad_out_v, grad_out_a, grad_out_j):
u_grad = None
(traj_dt, out_grad_position) = ctx.saved_tensors
if ctx.needs_input_grad[0]:
u_grad = tensor_step_pos_clique_bwd(
out_grad_position.transpose(-1, -2).contiguous(),
grad_out_p.transpose(-1, -2).contiguous(),
grad_out_v.transpose(-1, -2).contiguous(),
grad_out_a.transpose(-1, -2).contiguous(),
grad_out_j.transpose(-1, -2).contiguous(),
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
)
return (
u_grad.transpose(-1, -2).contiguous(),
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class AccelerationTensorStepKernel(torch.autograd.Function):
@staticmethod
def forward(
ctx,
u_act,
start_position,
start_velocity,
start_acceleration,
out_position,
out_velocity,
out_acceleration,
out_jerk,
traj_dt,
out_grad_position,
):
state_position, state_velocity, state_acceleration, state_jerk = tensor_step_acc_fwd(
out_position,
out_velocity,
out_acceleration,
out_jerk,
u_act,
start_position,
start_velocity,
start_acceleration,
traj_dt,
out_position.shape[0],
out_position.shape[1],
out_position.shape[-1],
)
ctx.save_for_backward(traj_dt, out_grad_position)
return state_position, state_velocity, state_acceleration, state_jerk
@staticmethod
def backward(ctx, grad_out_p, grad_out_v, grad_out_a, grad_out_j):
u_grad = None
(traj_dt, out_grad_position) = ctx.saved_tensors
if ctx.needs_input_grad[0]:
raise NotImplementedError()
u_grad = tensor_step_pos_clique_bwd(
out_grad_position,
grad_out_p,
grad_out_v,
grad_out_a,
grad_out_j,
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
)
return u_grad, None, None, None, None, None, None, None, None, None
class AccelerationTensorStepIdxKernel(torch.autograd.Function):
@staticmethod
def forward(
ctx,
u_act,
start_position,
start_velocity,
start_acceleration,
start_idx,
out_position,
out_velocity,
out_acceleration,
out_jerk,
traj_dt,
out_grad_position,
):
state_position, state_velocity, state_acceleration, state_jerk = tensor_step_acc_idx_fwd(
out_position,
out_velocity,
out_acceleration,
out_jerk,
u_act,
start_position,
start_velocity,
start_acceleration,
start_idx,
traj_dt,
out_position.shape[0],
out_position.shape[1],
out_position.shape[-1],
)
ctx.save_for_backward(traj_dt, out_grad_position)
return state_position, state_velocity, state_acceleration, state_jerk
@staticmethod
def backward(ctx, grad_out_p, grad_out_v, grad_out_a, grad_out_j):
u_grad = None
(traj_dt, out_grad_position) = ctx.saved_tensors
if ctx.needs_input_grad[0]:
raise NotImplementedError()
u_grad = tensor_step_pos_clique_bwd(
out_grad_position,
grad_out_p,
grad_out_v,
grad_out_a,
grad_out_j,
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
)
return u_grad, None, None, None, None, None, None, None, None, None, None
# @torch.jit.script
def tensor_step_pos_clique(
state: JointState,
act: torch.Tensor,
state_seq: JointState,
mask_matrix: List[torch.Tensor],
fd_matrix: List[torch.Tensor],
):
(
state_seq.position,
state_seq.velocity,
state_seq.acceleration,
state_seq.jerk,
) = CliqueTensorStep.apply(
act,
state.position,
mask_matrix[0],
mask_matrix[1],
fd_matrix[0],
fd_matrix[1],
fd_matrix[2],
)
return state_seq
def step_acc_semi_euler(state, act, diag_dt, n_dofs, integrate_matrix):
q = state[..., :n_dofs]
qd = state[..., n_dofs : 2 * n_dofs]
qdd_new = act
# diag_dt = torch.diag(dt_h)
qd_new = euler_integrate(qd, qdd_new, diag_dt, integrate_matrix)
q_new = euler_integrate(q, qd_new, diag_dt, integrate_matrix)
state_seq = torch.cat((q_new, qd_new, qdd_new), dim=-1)
return state_seq
# @torch.jit.script
def tensor_step_acc_semi_euler(
state, act, state_seq, diag_dt, integrate_matrix, integrate_matrix_pos
):
# type: (Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
# This is batch,n_dof
state = state.unsqueeze(1)
q = state.position # [..., :n_dofs]
qd = state.velocity # [..., n_dofs : 2 * n_dofs]
qdd_new = act
# diag_dt = torch.diag(dt_h)
qd_new = euler_integrate(qd, qdd_new, diag_dt, integrate_matrix)
q_new = euler_integrate(q, qd_new, diag_dt, integrate_matrix_pos)
state_seq.acceleration = qdd_new
state_seq.velocity = qd_new
state_seq.position = q_new
return state_seq
# @torch.jit.script
def tensor_step_vel(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix):
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Tensor) -> Tensor
# This is batch,n_dof
state_seq[:, 0:1, : n_dofs * 3] = state
q = state[..., :n_dofs]
qd_new = act[:, :-1, :]
# integrate velocities:
dt_diag = torch.diag(dt_h)
state_seq[:, 1:, n_dofs : n_dofs * 2] = qd_new
qd = state_seq[:, :, n_dofs : n_dofs * 2]
q_new = euler_integrate(q, qd, dt_diag, integrate_matrix)
state_seq[:, :, :n_dofs] = q_new
qdd = (torch.diag(1 / dt_h)) @ fd_matrix @ qd
state_seq[:, 1:, n_dofs * 2 : n_dofs * 3] = qdd[:, :-1, :]
return state_seq
# @torch.jit.script
def tensor_step_pos(state, act, state_seq, fd_matrix):
# This is batch,n_dof
state_seq.position[:, 0, :] = state.position
state_seq.velocity[:, 0, :] = state.velocity
state_seq.acceleration[:, 0, :] = state.acceleration
# integrate velocities:
state_seq.position[:, 1:] = act[:, :-1, :]
qd = fd_matrix @ state_seq.position # [:, :, :n_dofs]
state_seq.velocity[:, 1:] = qd[:, :-1, :] # qd_new
qdd = fd_matrix @ state_seq.velocity # [:, :, n_dofs : n_dofs * 2]
state_seq.acceleration[:, 1:] = qdd[:, :-1, :]
# jerk = fd_matrix @ state_seq.acceleration
return state_seq
# @torch.jit.script
def tensor_step_pos_ik(act, state_seq):
state_seq.position = act
return state_seq
def tensor_linspace(start_tensor, end_tensor, steps=10):
dist = end_tensor - start_tensor
interpolate_matrix = (
torch.ones((steps), device=start_tensor.device, dtype=start_tensor.dtype) / steps
)
cum_matrix = torch.cumsum(interpolate_matrix, dim=0)
interp_tensor = start_tensor + cum_matrix * dist
return interp_tensor
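# Illustrative sketch (hypothetical helper, not part of the original API): note that
# tensor_linspace steps from start + dist / steps up to end in equal increments, so the
# start value itself is not part of the returned tensor.
def _example_tensor_linspace() -> torch.Tensor:
    start = torch.zeros(1)
    end = torch.ones(1) * 10.0
    return tensor_linspace(start, end, steps=5)  # [2., 4., 6., 8., 10.]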
def sum_matrix(h, int_steps, tensor_args):
sum_mat = torch.zeros(((h - 1) * int_steps, h), **vars(tensor_args))
for i in range(h - 1):
sum_mat[i * int_steps : i * int_steps + int_steps, i] = 1.0
# hack:
# sum_mat[-1, -1] = 1.0
return sum_mat
def interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType):
mat = torch.zeros(
((h - 1) * (int_steps), h), device=tensor_args.device, dtype=tensor_args.dtype
)
delta = torch.arange(0, int_steps, device=tensor_args.device, dtype=tensor_args.dtype) / (
int_steps - 1
)
for i in range(h - 1):
mat[i * int_steps : i * int_steps + int_steps, i] = delta.flip(0)
mat[i * int_steps : i * int_steps + int_steps, i + 1] = delta
return mat
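# Illustrative sketch (hypothetical helper, not part of the original API; assumes
# TensorDeviceType can be constructed with a cpu device): interpolate_kernel builds a dense
# linear-interpolation matrix, so kernel @ waypoints yields int_steps samples for every
# pair of consecutive waypoints along the horizon.
def _example_interpolate_kernel() -> torch.Tensor:
    tensor_args = TensorDeviceType(device=torch.device("cpu"))
    kernel = interpolate_kernel(3, 4, tensor_args)  # shape: ((3 - 1) * 4, 3)
    waypoints = torch.tensor([[0.0], [1.0], [3.0]])
    return kernel @ waypoints  # linear samples between 0 -> 1 and 1 -> 3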
def action_interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType):
mat = torch.zeros(
((h - 1) * (int_steps), h), device=tensor_args.device, dtype=tensor_args.dtype
)
delta = torch.arange(0, int_steps - 2, device=tensor_args.device, dtype=tensor_args.dtype) / (
int_steps - 1.0 - 2
)
for i in range(h - 1):
mat[i * int_steps : i * (int_steps) + int_steps - 1 - 2, i] = delta.flip(0)[1:]
mat[i * int_steps : i * (int_steps) + int_steps - 1 - 2, i + 1] = delta[1:]
mat[-3:, 1] = 1.0
return mat

View File

@@ -0,0 +1,605 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Union
# Third Party
import torch
import torch.autograd.profiler as profiler
# CuRobo
from curobo.cuda_robot_model.cuda_robot_model import CudaRobotModel
from curobo.rollout.dynamics_model.tensor_step import (
TensorStepAccelerationKernel,
TensorStepPosition,
TensorStepPositionClique,
TensorStepPositionCliqueKernel,
TensorStepPositionTeleport,
)
from curobo.types.base import TensorDeviceType
from curobo.types.enum import StateType
from curobo.types.math import Pose
from curobo.types.robot import JointState, RobotConfig
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_error, log_info
from curobo.util.state_filter import FilterConfig, JointStateFilter
@dataclass
class TimeTrajConfig:
base_dt: float
base_ratio: float
max_dt: float
def get_dt_array(self, num_points: int):
dt_array = [self.base_dt] * int(self.base_ratio * num_points)
smooth_blending = torch.linspace(
self.base_dt,
self.max_dt,
steps=int((1 - self.base_ratio) * num_points),
).tolist()
dt_array += smooth_blending
if len(dt_array) != num_points:
dt_array.insert(0, dt_array[0])
return dt_array
def update_dt(
self,
all_dt: float = None,
base_dt: float = None,
max_dt: float = None,
base_ratio: float = None,
):
if all_dt is not None:
self.base_dt = all_dt
self.max_dt = all_dt
return
if base_dt is not None:
self.base_dt = base_dt
if base_ratio is not None:
self.base_ratio = base_ratio
if max_dt is not None:
self.max_dt = max_dt
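# Illustrative sketch (hypothetical helper, not part of the original API): TimeTrajConfig
# keeps the first base_ratio fraction of timesteps at base_dt and then blends linearly up
# to max_dt, so later steps of the horizon cover more time.
def _example_time_traj_config() -> List[float]:
    dt_cfg = TimeTrajConfig(base_dt=0.02, base_ratio=0.5, max_dt=0.1)
    dt_array = dt_cfg.get_dt_array(10)
    # first 5 entries are 0.02; the remaining 5 blend from 0.02 up to 0.1
    return dt_array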
@dataclass
class KinematicModelState(Sequence):
# TODO: subclass this from State
state_seq: JointState
ee_pos_seq: Optional[torch.Tensor] = None
ee_quat_seq: Optional[torch.Tensor] = None
robot_spheres: Optional[torch.Tensor] = None
link_pos_seq: Optional[torch.Tensor] = None
link_quat_seq: Optional[torch.Tensor] = None
lin_jac_seq: Optional[torch.Tensor] = None
ang_jac_seq: Optional[torch.Tensor] = None
link_names: Optional[List[str]] = None
def __getitem__(self, idx):
d_list = [
self.state_seq,
self.ee_pos_seq,
self.ee_quat_seq,
self.robot_spheres,
self.link_pos_seq,
self.link_quat_seq,
self.lin_jac_seq,
self.ang_jac_seq,
]
idx_vals = list_idx_if_not_none(d_list, idx)
return KinematicModelState(*idx_vals, link_names=self.link_names)
def __len__(self):
return len(self.state_seq)
@property
def ee_pose(self) -> Pose:
return Pose(self.ee_pos_seq, self.ee_quat_seq, normalize_rotation=False)
@property
def link_pose(self):
if self.link_names is not None:
link_pos_seq = self.link_pos_seq.contiguous()
link_quat_seq = self.link_quat_seq.contiguous()
link_poses = {}
for i, v in enumerate(self.link_names):
link_poses[v] = Pose(
link_pos_seq[..., i, :], link_quat_seq[..., i, :], normalize_rotation=False
)
else:
link_poses = None
return link_poses
@dataclass(frozen=False)
class KinematicModelConfig:
robot_config: RobotConfig
dt_traj_params: TimeTrajConfig
tensor_args: TensorDeviceType
vel_scale: float = 1.0
state_estimation_variance: float = 0.0
batch_size: int = 1
horizon: int = 5
control_space: StateType = StateType.ACCELERATION
state_filter_cfg: Optional[FilterConfig] = None
teleport_mode: bool = False
return_full_act_buffer: bool = False
state_finite_difference_mode: str = "BACKWARD"
filter_robot_command: bool = False
# tensor_step_type: TensorStepType = TensorStepType.ACCELERATION
@staticmethod
def from_dict(
data_dict_in, robot_cfg: Union[Dict, RobotConfig], tensor_args=TensorDeviceType()
):
data_dict = deepcopy(data_dict_in)
if isinstance(robot_cfg, dict):
robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args)
data_dict["robot_config"] = robot_cfg
data_dict["dt_traj_params"] = TimeTrajConfig(**data_dict["dt_traj_params"])
data_dict["control_space"] = StateType[data_dict["control_space"]]
data_dict["state_filter_cfg"] = FilterConfig.from_dict(
data_dict["state_filter_cfg"]["filter_coeff"],
enable=data_dict["state_filter_cfg"]["enable"],
dt=data_dict["dt_traj_params"].base_dt,
control_space=data_dict["control_space"],
tensor_args=tensor_args,
teleport_mode=data_dict["teleport_mode"],
)
return KinematicModelConfig(**data_dict, tensor_args=tensor_args)
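# Editorial sketch (not part of the release) of the dictionary layout that
# KinematicModelConfig.from_dict parses above. The "filter_coeff" contents depend
# on FilterConfig.from_dict and are shown only as a placeholder assumption.
def _example_kinematic_model_config_dict():
    return {
        "dt_traj_params": {"base_dt": 0.02, "base_ratio": 1.0, "max_dt": 0.25},
        "control_space": "POSITION",  # parsed via StateType["POSITION"]
        "teleport_mode": False,
        "state_filter_cfg": {
            "filter_coeff": {"position": 1.0, "velocity": 1.0, "acceleration": 0.0},
            "enable": False,
        },
    }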
class KinematicModel(KinematicModelConfig):
def __init__(
self,
kinematic_model_config: KinematicModelConfig,
):
super().__init__(**vars(kinematic_model_config))
self.dt = self.dt_traj_params.base_dt
self.robot_model = CudaRobotModel(self.robot_config.kinematics)
# update cspace to store joint names in the order given by robot model:
self.n_dofs = self.robot_model.get_dof()
self._use_clique = True
self._use_bmm_tensor_step = False
self._use_clique_kernel = True
self.d_state = 4 * self.n_dofs # + 1
self.d_action = self.n_dofs
# Variables for enforcing joint limits
self.joint_names = self.robot_model.joint_names
self.joint_limits = self.robot_model.get_joint_limits()
        # pre-allocating memory for rollouts
self.state_seq = JointState.from_state_tensor(
torch.zeros(
self.batch_size,
self.horizon,
self.d_state,
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
),
dof=int(self.d_state / 3),
)
self.Z = torch.tensor([0.0], device=self.tensor_args.device, dtype=self.tensor_args.dtype)
dt_array = self.dt_traj_params.get_dt_array(self.horizon)
self.traj_dt = torch.tensor(
dt_array, dtype=self.tensor_args.dtype, device=self.tensor_args.device
)
# TODO: choose tensor integration type here:
if self.control_space == StateType.ACCELERATION:
# self._rollout_step_fn = TensorStepAcceleration(self.tensor_args, self._dt_h)
# self._cmd_step_fn = TensorStepAcceleration(self.tensor_args, self.traj_dt)
self._rollout_step_fn = TensorStepAccelerationKernel(
self.tensor_args, self.traj_dt, self.n_dofs
)
self._cmd_step_fn = TensorStepAccelerationKernel(
self.tensor_args, self.traj_dt, self.n_dofs
)
elif self.control_space == StateType.VELOCITY:
raise NotImplementedError()
elif self.control_space == StateType.JERK:
raise NotImplementedError()
elif self.control_space == StateType.POSITION:
if self.teleport_mode:
self._rollout_step_fn = TensorStepPositionTeleport(self.tensor_args)
self._cmd_step_fn = TensorStepPositionTeleport(self.tensor_args)
else:
if self._use_clique:
if self._use_clique_kernel:
if self.state_finite_difference_mode == "BACKWARD":
finite_difference = -1
elif self.state_finite_difference_mode == "CENTRAL":
finite_difference = 0
else:
log_error(
"unknown state finite difference mode: "
+ self.state_finite_difference_mode
)
self._rollout_step_fn = TensorStepPositionCliqueKernel(
self.tensor_args,
self.traj_dt,
self.n_dofs,
finite_difference_mode=finite_difference,
filter_velocity=False,
filter_acceleration=False,
filter_jerk=False,
)
self._cmd_step_fn = TensorStepPositionCliqueKernel(
self.tensor_args,
self.traj_dt,
self.n_dofs,
finite_difference_mode=finite_difference,
filter_velocity=False,
filter_acceleration=self.filter_robot_command,
filter_jerk=self.filter_robot_command,
)
else:
self._rollout_step_fn = TensorStepPositionClique(
self.tensor_args, self.traj_dt
)
self._cmd_step_fn = TensorStepPositionClique(self.tensor_args, self.traj_dt)
else:
self._rollout_step_fn = TensorStepPosition(self.tensor_args, self.traj_dt)
self._cmd_step_fn = TensorStepPosition(self.tensor_args, self.traj_dt)
self.update_batch_size(self.batch_size)
self.state_filter = JointStateFilter(self.state_filter_cfg)
self._robot_cmd_state_seq = JointState.zeros(
(1, self.horizon, self.d_action), self.tensor_args
)
self._cmd_batch_size = -1
if not self.teleport_mode:
self._max_joint_vel = (
self.get_state_bounds()
.velocity.view(2, self.d_action)[1, :]
.reshape(1, 1, self.d_action)
) - 0.2
self._max_joint_acc = self.get_state_bounds().acceleration[1, :] - 0.2
self._max_joint_jerk = self.get_state_bounds().jerk[1, :] - 0.2
def update_traj_dt(
self,
dt: Union[float, torch.Tensor],
base_dt: Optional[float] = None,
max_dt: Optional[float] = None,
base_ratio: Optional[float] = None,
):
self.dt_traj_params.update_dt(dt, base_dt, max_dt, base_ratio)
dt_array = self.dt_traj_params.get_dt_array(self.horizon)
self.traj_dt[:] = torch.tensor(
dt_array, dtype=self.tensor_args.dtype, device=self.tensor_args.device
)
self._cmd_step_fn.update_dt(self.traj_dt)
self._rollout_step_fn.update_dt(self.traj_dt)
def get_next_state(self, curr_state: torch.Tensor, act: torch.Tensor, dt):
"""Does a single step from the current state
Args:
curr_state: current state
act: action
dt: time to integrate
Returns:
next_state
TODO: Move this into tensorstep class?
"""
if self.control_space == StateType.JERK:
curr_state[2 * self.n_dofs : 3 * self.n_dofs] = (
curr_state[self.n_dofs : 2 * self.n_dofs] + act * dt
)
curr_state[self.n_dofs : 2 * self.n_dofs] = (
curr_state[self.n_dofs : 2 * self.n_dofs]
+ curr_state[self.n_dofs * 2 : self.n_dofs * 3] * dt
)
curr_state[: self.n_dofs] = (
curr_state[: self.n_dofs] + curr_state[self.n_dofs : 2 * self.n_dofs] * dt
)
elif self.control_space == StateType.ACCELERATION:
curr_state[2 * self.n_dofs : 3 * self.n_dofs] = act
curr_state[self.n_dofs : 2 * self.n_dofs] = (
curr_state[self.n_dofs : 2 * self.n_dofs]
+ curr_state[self.n_dofs * 2 : self.n_dofs * 3] * dt
)
curr_state[: self.n_dofs] = (
curr_state[: self.n_dofs]
+ curr_state[self.n_dofs : 2 * self.n_dofs] * dt
+ 0.5 * act * dt * dt
)
elif self.control_space == StateType.VELOCITY:
curr_state[2 * self.n_dofs : 3 * self.n_dofs] = 0.0
curr_state[self.n_dofs : 2 * self.n_dofs] = act * dt
curr_state[: self.n_dofs] = (
curr_state[: self.n_dofs] + curr_state[self.n_dofs : 2 * self.n_dofs] * dt
)
elif self.control_space == StateType.POSITION:
curr_state[2 * self.n_dofs : 3 * self.n_dofs] = 0.0
curr_state[1 * self.n_dofs : 2 * self.n_dofs] = 0.0
curr_state[: self.n_dofs] = act
return curr_state
def tensor_step(
self,
state: JointState,
act: torch.Tensor,
state_seq: JointState,
state_idx: Optional[torch.Tensor] = None,
) -> JointState:
"""
Args:
state: [1,N]
act: [H,N]
todo:
Integration with variable dt along trajectory
"""
state_seq = self._rollout_step_fn.forward(state, act, state_seq, state_idx)
return state_seq
def robot_cmd_tensor_step(
self,
state: JointState,
act: torch.Tensor,
state_seq: JointState,
state_idx: Optional[torch.Tensor] = None,
) -> JointState:
"""
Args:
state: [1,N]
act: [H,N]
todo:
Integration with variable dt along trajectory
"""
state_seq = self._cmd_step_fn.forward(state, act, state_seq, state_idx)
state_seq.joint_names = self.joint_names
return state_seq
def update_cmd_batch_size(self, batch_size):
if self._cmd_batch_size != batch_size:
self._robot_cmd_state_seq = JointState.zeros(
(batch_size, self.horizon, self.d_action), self.tensor_args
)
self._cmd_step_fn.update_batch_size(batch_size, self.horizon)
self._cmd_batch_size = batch_size
def update_batch_size(self, batch_size, force_update=False):
if self.batch_size != batch_size:
# TODO: Remove tensor recreation upon force update?
self.state_seq = JointState.zeros(
(batch_size, self.horizon, self.d_action), self.tensor_args
)
log_info("Updating state_seq buffer reference (created new tensor)")
# print("Creating new tensor")
if force_update:
self.state_seq = self.state_seq.detach()
self._rollout_step_fn.update_batch_size(batch_size, self.horizon, force_update)
self.batch_size = batch_size
def forward(
self,
start_state: JointState,
act_seq: torch.Tensor,
start_state_idx: Optional[torch.Tensor] = None,
) -> KinematicModelState:
# filter state if needed:
start_state_shaped = start_state # .unsqueeze(1)
# batch_size, horizon, d_act = act_seq.shape
batch_size = act_seq.shape[0]
self.update_batch_size(batch_size, force_update=act_seq.requires_grad)
state_seq = self.state_seq
curr_batch_size = self.batch_size
num_traj_points = self.horizon
with profiler.record_function("tensor_step"):
# forward step with step matrix:
state_seq = self.tensor_step(start_state_shaped, act_seq, state_seq, start_state_idx)
shape_tup = (curr_batch_size * num_traj_points, self.n_dofs)
with profiler.record_function("fk + jacobian"):
(
ee_pos_seq,
ee_quat_seq,
lin_jac_seq,
ang_jac_seq,
link_pos_seq,
link_quat_seq,
link_spheres,
) = self.robot_model.forward(state_seq.position.view(shape_tup))
link_pos_seq = link_pos_seq.view(
((curr_batch_size, num_traj_points, link_pos_seq.shape[1], 3))
)
link_quat_seq = link_quat_seq.view(
((curr_batch_size, num_traj_points, link_quat_seq.shape[1], 4))
)
link_spheres = link_spheres.view(
(curr_batch_size, num_traj_points, link_spheres.shape[1], link_spheres.shape[-1])
)
ee_pos_seq = ee_pos_seq.view((curr_batch_size, num_traj_points, 3))
ee_quat_seq = ee_quat_seq.view((curr_batch_size, num_traj_points, 4))
if lin_jac_seq is not None:
lin_jac_seq = lin_jac_seq.view((curr_batch_size, num_traj_points, 3, self.n_dofs))
if ang_jac_seq is not None:
ang_jac_seq = ang_jac_seq.view((curr_batch_size, num_traj_points, 3, self.n_dofs))
state = KinematicModelState(
state_seq,
ee_pos_seq,
ee_quat_seq,
link_spheres,
link_pos_seq,
link_quat_seq,
lin_jac_seq,
ang_jac_seq,
link_names=self.robot_model.link_names,
)
return state
def integrate_action(self, act_seq):
if self.action_order == 0:
return act_seq
nth_act_seq = self._integrate_matrix_nth @ act_seq
return nth_act_seq
def integrate_action_step(self, act, dt):
for i in range(self.action_order):
act = act * dt
return act
def filter_robot_state(self, current_state: JointState):
filtered_state = self.state_filter.filter_joint_state(current_state)
return filtered_state
@torch.no_grad()
def get_robot_command(
self,
current_state: JointState,
act_seq: torch.Tensor,
shift_steps: int = 1,
state_idx: Optional[torch.Tensor] = None,
) -> JointState:
if self.return_full_act_buffer:
if act_seq.shape[0] != self._cmd_batch_size:
self.update_cmd_batch_size(act_seq.shape[0])
full_state = self.robot_cmd_tensor_step(
current_state,
act_seq,
self._robot_cmd_state_seq,
state_idx,
)
return full_state
        if shift_steps == 1:
            act_step = act_seq[..., 0, :].clone()
cmd = self.state_filter.integrate_action(act_step, current_state)
return cmd
        # integrate the first shift_steps actions into a command buffer:
cmd = current_state.clone()
for i in range(shift_steps):
act_step = act_seq[..., i, :]
# we integrate the action with the current belief:
cmd = self.state_filter.integrate_action(act_step, cmd)
if i == 0:
cmd_buffer = cmd.clone()
else:
cmd_buffer = cmd_buffer.stack(cmd)
return cmd_buffer
@property
def action_bound_lows(self):
if self.control_space == StateType.POSITION:
# use joint limits:
return self.joint_limits.position[0]
if self.control_space == StateType.VELOCITY:
# use joint limits:
return self.joint_limits.velocity[0]
if self.control_space == StateType.ACCELERATION:
# use joint limits:
return self.joint_limits.acceleration[0]
@property
def action_bound_highs(self):
if self.control_space == StateType.POSITION:
# use joint limits:
return self.joint_limits.position[1]
if self.control_space == StateType.VELOCITY:
# use joint limits:
return self.joint_limits.velocity[1]
if self.control_space == StateType.ACCELERATION:
# use joint limits:
return self.joint_limits.acceleration[1]
@property
def init_action_mean(self):
# output should be d_action * horizon
if self.control_space == StateType.POSITION:
# use joint limits:
return self.retract_config.unsqueeze(0).repeat(self.horizon, 1)
        if self.control_space in (StateType.VELOCITY, StateType.ACCELERATION):
# use joint limits:
return self.retract_config.unsqueeze(0).repeat(self.horizon, 1) * 0.0
@property
def retract_config(self):
return self.robot_model.kinematics_config.cspace.retract_config
@property
def cspace_distance_weight(self):
return self.robot_model.kinematics_config.cspace.cspace_distance_weight
@property
def null_space_weight(self):
return self.robot_model.kinematics_config.cspace.null_space_weight
@property
def max_acceleration(self):
return self.get_state_bounds().acceleration[1, 0].item()
@property
def max_jerk(self):
return self.get_state_bounds().jerk[1, 0].item()
def get_state_bounds(self):
joint_limits = self.robot_model.get_joint_limits()
return joint_limits
def get_action_from_state(self, state: JointState) -> torch.Tensor:
if self.control_space == StateType.POSITION:
return state.position
if self.control_space == StateType.VELOCITY:
return state.velocity
if self.control_space == StateType.ACCELERATION:
return state.acceleration
def get_state_from_action(
self,
start_state: JointState,
act_seq: torch.Tensor,
state_idx: Optional[torch.Tensor] = None,
) -> JointState:
"""Compute State sequence from an action trajectory
Args:
start_state (JointState): _description_
act_seq (torch.Tensor): _description_
Returns:
JointState: _description_
"""
if act_seq.shape[0] != self._cmd_batch_size:
self.update_cmd_batch_size(act_seq.shape[0])
full_state = self.robot_cmd_tensor_step(
start_state,
act_seq,
self._robot_cmd_state_seq,
state_idx,
)
return full_state
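# Editorial sketch (not part of the release): the ACCELERATION branch of
# KinematicModel.get_next_state writes the commanded acceleration, updates the
# velocity with it, and then updates the position with the new velocity plus a
# 0.5 * a * dt^2 term. The standalone version below mirrors that update on a
# packed [position, velocity, acceleration, ...] state vector with illustrative sizes.
def _example_acceleration_step(n_dofs: int = 3, dt: float = 0.02):
    state = torch.zeros(4 * n_dofs)  # packed [position, velocity, acceleration, ...]
    act = torch.ones(n_dofs)  # commanded acceleration
    state[2 * n_dofs : 3 * n_dofs] = act
    state[n_dofs : 2 * n_dofs] = state[n_dofs : 2 * n_dofs] + state[2 * n_dofs : 3 * n_dofs] * dt
    state[:n_dofs] = state[:n_dofs] + state[n_dofs : 2 * n_dofs] * dt + 0.5 * act * dt * dt
    return state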

View File

@@ -0,0 +1,33 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from abc import ABC, abstractmethod
class DynamicsModelBase(ABC):
def __init__(self):
pass
@abstractmethod
def forward(self, start_state, act_seq, *args):
pass
@abstractmethod
    def get_next_state(self, current_state, act, dt):
pass
@abstractmethod
def filter_robot_state(self, current_state):
pass
@abstractmethod
def get_robot_command(self, current_state, act_seq):
pass
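# Editorial sketch (not part of the release): a minimal concrete subclass showing
# which methods DynamicsModelBase requires. The single-integrator dynamics below
# are illustrative only.
class _ExampleSingleIntegrator(DynamicsModelBase):
    def forward(self, start_state, act_seq, *args):
        # roll the state forward by integrating each action in the sequence
        states = [start_state]
        for act in act_seq:
            states.append(self.get_next_state(states[-1], act, dt=0.1))
        return states

    def get_next_state(self, current_state, act, dt):
        return current_state + act * dt

    def filter_robot_state(self, current_state):
        return current_state  # no filtering in this sketch

    def get_robot_command(self, current_state, act_seq):
        # command the first action in the sequence
        return act_seq[0]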

View File

@@ -0,0 +1,524 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from abc import abstractmethod
from enum import Enum
from typing import Optional
# Third Party
import torch
# CuRobo
from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState
# Local Folder
from .integration_utils import (
AccelerationTensorStepIdxKernel,
AccelerationTensorStepKernel,
CliqueTensorStepCentralDifferenceKernel,
CliqueTensorStepIdxCentralDifferenceKernel,
CliqueTensorStepIdxKernel,
CliqueTensorStepKernel,
build_fd_matrix,
build_int_matrix,
build_start_state_mask,
tensor_step_acc_semi_euler,
tensor_step_pos,
tensor_step_pos_clique,
)
class TensorStepType(Enum):
POSITION_TELEPORT = 0
POSITION_CLIQUE_KERNEL = 1
VELOCITY = 2 # Not implemented
ACCELERATION_KERNEL = 3
JERK = 4 # Not implemented
POSITION = 5 # deprecated
POSITION_CLIQUE = 6 # deprecated
ACCELERATION = 7 # deprecated
class TensorStepBase:
def __init__(self, tensor_args: TensorDeviceType) -> None:
self.batch_size = -1
self.horizon = -1
self.tensor_args = tensor_args
self._diag_dt = None
self._inv_dt_h = None
def update_dt(self, dt: float):
self._dt_h[:] = dt
if self._inv_dt_h is not None:
self._inv_dt_h[:] = 1.0 / dt
@abstractmethod
def update_batch_size(
self,
batch_size: Optional[int] = None,
horizon: Optional[int] = None,
force_update: bool = False,
) -> None:
self.horizon = horizon
self.batch_size = batch_size
@abstractmethod
def forward(
self,
start_state: JointState,
u_act: torch.Tensor,
out_state_seq: JointState,
start_state_idx: Optional[torch.Tensor] = None,
) -> JointState:
pass
class TensorStepAcceleration(TensorStepBase):
def __init__(self, tensor_args: TensorDeviceType, dt_h: torch.Tensor) -> None:
super().__init__(tensor_args)
self._dt_h = dt_h
self._diag_dt_h = torch.diag(self._dt_h)
self._integrate_matrix_pos = None
self._integrate_matrix_vel = None
def update_batch_size(
self,
batch_size: Optional[int] = None,
horizon: Optional[int] = None,
force_update: bool = False,
) -> None:
if self.horizon != horizon:
self._integrate_matrix_pos = (
build_int_matrix(
horizon,
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
diagonal=0,
)
@ self._diag_dt_h
)
self._integrate_matrix_vel = self._integrate_matrix_pos @ self._diag_dt_h
return super().update_batch_size(batch_size, horizon)
def forward(
self,
start_state: JointState,
u_act: torch.Tensor,
out_state_seq: JointState,
start_state_idx: Optional[torch.Tensor] = None,
) -> JointState:
if start_state_idx is None:
state_seq = tensor_step_acc_semi_euler(
start_state,
u_act,
out_state_seq,
self._diag_dt_h,
self._integrate_matrix_vel,
self._integrate_matrix_pos,
)
else:
state_seq = tensor_step_acc_semi_euler(
start_state[start_state_idx],
u_act,
out_state_seq,
self._diag_dt_h,
self._integrate_matrix_vel,
self._integrate_matrix_pos,
)
return state_seq
class TensorStepPositionTeleport(TensorStepBase):
def __init__(self, tensor_args: TensorDeviceType) -> None:
super().__init__(tensor_args)
def forward(
self,
start_state: JointState,
u_act: torch.Tensor,
out_state_seq: JointState,
start_state_idx: Optional[torch.Tensor] = None,
) -> JointState:
out_state_seq.position = u_act
return out_state_seq
class TensorStepPosition(TensorStepBase):
def __init__(self, tensor_args: TensorDeviceType, dt_h: torch.Tensor) -> None:
super().__init__(tensor_args)
self._dt_h = dt_h
# self._diag_dt_h = torch.diag(1 / self._dt_h)
self._fd_matrix = None
def update_dt(self, dt: float):
super().update_dt(dt)
self._fd_matrix = build_fd_matrix(
self.horizon,
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
order=1,
)
self._fd_matrix = torch.diag(1.0 / self._dt_h) @ self._fd_matrix
def update_batch_size(
self,
batch_size: Optional[int] = None,
horizon: Optional[int] = None,
force_update: bool = False,
) -> None:
if horizon != self.horizon:
self._fd_matrix = build_fd_matrix(
horizon,
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
order=1,
)
self._fd_matrix = torch.diag(1.0 / self._dt_h) @ self._fd_matrix
# self._fd_matrix = self._diag_dt_h @ self._fd_matrix
return super().update_batch_size(batch_size, horizon)
def forward(
self,
start_state: JointState,
u_act: torch.Tensor,
out_state_seq: JointState,
start_state_idx: Optional[torch.Tensor] = None,
) -> JointState:
if start_state_idx is None:
state_seq = tensor_step_pos(start_state, u_act, out_state_seq, self._fd_matrix)
else:
state_seq = tensor_step_pos(
start_state[start_state_idx], u_act, out_state_seq, self._fd_matrix
)
return state_seq
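# Editorial usage sketch (not part of the release): TensorStepPosition scales a
# first-order finite-difference matrix by 1/dt to turn a position trajectory into
# velocities. The call below mirrors that use with illustrative sizes; the default
# TensorDeviceType (typically CUDA) is assumed.
def _example_fd_velocity(horizon: int = 6, dof: int = 2, dt: float = 0.1):
    tensor_args = TensorDeviceType()
    fd_matrix = build_fd_matrix(
        horizon, device=tensor_args.device, dtype=tensor_args.dtype, order=1
    )
    position = torch.randn((horizon, dof), device=tensor_args.device, dtype=tensor_args.dtype)
    velocity = (1.0 / dt) * (fd_matrix @ position)  # [horizon, dof]
    return velocity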
class TensorStepPositionClique(TensorStepBase):
def __init__(self, tensor_args: TensorDeviceType, dt_h: torch.Tensor) -> None:
super().__init__(tensor_args)
self._dt_h = dt_h
self._inv_dt_h = 1.0 / dt_h
self._fd_matrix = None
self._start_mask_matrix = None
def update_dt(self, dt: float):
super().update_dt(dt)
self._fd_matrix = []
for i in range(3):
self._fd_matrix.append(
build_fd_matrix(
self.horizon,
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
order=i + 1,
SHIFT=True,
)
)
self._diag_dt_h = torch.diag(self._inv_dt_h)
self._fd_matrix[0] = self._diag_dt_h @ self._fd_matrix[0]
self._fd_matrix[1] = self._diag_dt_h**2 @ self._fd_matrix[1]
self._fd_matrix[2] = self._diag_dt_h**3 @ self._fd_matrix[2]
def update_batch_size(
self,
batch_size: Optional[int] = None,
horizon: Optional[int] = None,
force_update: bool = False,
) -> None:
if self.horizon != horizon:
self._fd_matrix = []
for i in range(3):
self._fd_matrix.append(
build_fd_matrix(
horizon,
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
order=i + 1,
SHIFT=True,
)
)
self._diag_dt_h = torch.diag(self._inv_dt_h)
self._fd_matrix[0] = self._diag_dt_h @ self._fd_matrix[0]
self._fd_matrix[1] = self._diag_dt_h**2 @ self._fd_matrix[1]
self._fd_matrix[2] = self._diag_dt_h**3 @ self._fd_matrix[2]
self._start_mask_matrix = list(build_start_state_mask(horizon, self.tensor_args))
return super().update_batch_size(batch_size, horizon)
def forward(
self,
start_state: JointState,
u_act: torch.Tensor,
out_state_seq: JointState,
start_state_idx: Optional[torch.Tensor] = None,
) -> JointState:
if start_state_idx is None:
state_seq = tensor_step_pos_clique(
start_state, u_act, out_state_seq, self._start_mask_matrix, self._fd_matrix
)
else:
state_seq = tensor_step_pos_clique(
start_state[start_state_idx],
u_act,
out_state_seq,
self._start_mask_matrix,
self._fd_matrix,
)
return state_seq
class TensorStepAccelerationKernel(TensorStepBase):
def __init__(self, tensor_args: TensorDeviceType, dt_h: torch.Tensor, dof: int) -> None:
super().__init__(tensor_args)
self._dt_h = dt_h
self._u_grad = None
self.dof = dof
def update_batch_size(
self,
batch_size: Optional[int] = None,
horizon: Optional[int] = None,
force_update: bool = False,
) -> None:
if batch_size != self.batch_size or horizon != self.horizon:
self._u_grad = torch.zeros(
(batch_size, horizon, self.dof),
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
)
if force_update:
self._u_grad = self._u_grad.detach()
return super().update_batch_size(batch_size, horizon)
def forward(
self,
start_state: JointState,
u_act: torch.Tensor,
out_state_seq: JointState,
start_state_idx: Optional[torch.Tensor] = None,
) -> JointState:
if start_state_idx is None:
(
out_state_seq.position,
out_state_seq.velocity,
out_state_seq.acceleration,
out_state_seq.jerk,
) = AccelerationTensorStepKernel.apply(
u_act,
start_state.position, # .contiguous(),
start_state.velocity, # .contiguous(),
start_state.acceleration, # .contiguous(),
out_state_seq.position, # .contiguous(),
out_state_seq.velocity, # .contiguous(),
out_state_seq.acceleration, # .contiguous(),
out_state_seq.jerk, # .contiguous(),
self._dt_h,
self._u_grad,
)
else:
(
out_state_seq.position,
out_state_seq.velocity,
out_state_seq.acceleration,
out_state_seq.jerk,
) = AccelerationTensorStepIdxKernel.apply(
u_act,
start_state.position, # .contiguous(),
start_state.velocity, # .contiguous(),
start_state.acceleration, # .contiguous(),
start_state_idx,
out_state_seq.position, # .contiguous(),
out_state_seq.velocity, # .contiguous(),
out_state_seq.acceleration, # .contiguous(),
out_state_seq.jerk, # .contiguous(),
self._dt_h,
self._u_grad,
)
return out_state_seq
class TensorStepPositionCliqueKernel(TensorStepBase):
def __init__(
self,
tensor_args: TensorDeviceType,
dt_h: torch.Tensor,
dof: int,
finite_difference_mode: int = -1,
filter_velocity: bool = False,
filter_acceleration: bool = False,
filter_jerk: bool = False,
) -> None:
super().__init__(tensor_args)
self._dt_h = dt_h
self._inv_dt_h = 1.0 / dt_h
self._u_grad = None
self.dof = dof
self._fd_mode = finite_difference_mode
self._filter_velocity = filter_velocity
self._filter_acceleration = filter_acceleration
self._filter_jerk = filter_jerk
if self._filter_velocity or self._filter_acceleration or self._filter_jerk:
kernel = self.tensor_args.to_device([[[0.06136, 0.24477, 0.38774, 0.24477, 0.06136]]])
self._sma = torch.nn.functional.conv1d
weights = kernel
self._sma_kernel = weights
# self._sma = torch.nn.AvgPool1d(kernel_size=5, stride=2, padding=1).to(
# device=self.tensor_args.device
# )
def update_batch_size(
self,
batch_size: Optional[int] = None,
horizon: Optional[int] = None,
force_update: bool = False,
) -> None:
if batch_size != self.batch_size or horizon != self.horizon:
self._u_grad = torch.zeros(
(batch_size, horizon, self.dof),
device=self.tensor_args.device,
dtype=self.tensor_args.dtype,
)
if force_update:
self._u_grad = self._u_grad.detach()
return super().update_batch_size(batch_size, horizon)
def forward(
self,
start_state: JointState,
u_act: torch.Tensor,
out_state_seq: JointState,
start_state_idx: Optional[torch.Tensor] = None,
) -> JointState:
if start_state_idx is None:
if self._fd_mode == -1:
(
out_state_seq.position,
out_state_seq.velocity,
out_state_seq.acceleration,
out_state_seq.jerk,
) = CliqueTensorStepKernel.apply(
u_act,
start_state.position, # .contiguous(),
start_state.velocity, # .contiguous(),
start_state.acceleration, # .contiguous(),
out_state_seq.position, # .contiguous(),
out_state_seq.velocity, # .contiguous(),
out_state_seq.acceleration, # .contiguous(),
out_state_seq.jerk, # .contiguous(),
self._inv_dt_h,
self._u_grad,
)
else:
(
out_state_seq.position,
out_state_seq.velocity,
out_state_seq.acceleration,
out_state_seq.jerk,
) = CliqueTensorStepCentralDifferenceKernel.apply(
u_act,
start_state.position, # .contiguous(),
start_state.velocity, # .contiguous(),
start_state.acceleration, # .contiguous(),
out_state_seq.position, # .contiguous(),
out_state_seq.velocity, # .contiguous(),
out_state_seq.acceleration, # .contiguous(),
out_state_seq.jerk, # .contiguous(),
self._inv_dt_h,
self._u_grad,
)
else:
if self._fd_mode == -1:
(
out_state_seq.position,
out_state_seq.velocity,
out_state_seq.acceleration,
out_state_seq.jerk,
) = CliqueTensorStepIdxKernel.apply(
u_act,
start_state.position, # .contiguous(),
start_state.velocity, # .contiguous(),
start_state.acceleration, # .contiguous(),
start_state_idx,
out_state_seq.position, # .contiguous(),
out_state_seq.velocity, # .contiguous(),
out_state_seq.acceleration, # .contiguous(),
out_state_seq.jerk, # .contiguous(),
self._inv_dt_h,
self._u_grad,
)
else:
(
out_state_seq.position,
out_state_seq.velocity,
out_state_seq.acceleration,
out_state_seq.jerk,
) = CliqueTensorStepIdxCentralDifferenceKernel.apply(
u_act,
start_state.position, # .contiguous(),
start_state.velocity, # .contiguous(),
start_state.acceleration, # .contiguous(),
start_state_idx,
out_state_seq.position, # .contiguous(),
out_state_seq.velocity, # .contiguous(),
out_state_seq.acceleration, # .contiguous(),
out_state_seq.jerk, # .contiguous(),
self._inv_dt_h,
self._u_grad,
)
if self._filter_velocity:
out_state_seq.velocity = self.filter_signal(out_state_seq.velocity)
if self._filter_acceleration:
out_state_seq.acceleration = self.filter_signal(out_state_seq.acceleration)
if self._filter_jerk:
out_state_seq.jerk = self.filter_signal(out_state_seq.jerk)
return out_state_seq
    def filter_signal(self, signal: torch.Tensor):
        # smooth along the horizon dimension with the jit-scripted helper below:
        return filter_signal_jit(signal, self._sma_kernel)
@torch.jit.script
def filter_signal_jit(signal, kernel):
b, h, dof = signal.shape
new_signal = (
torch.nn.functional.conv1d(
signal.transpose(-1, -2).reshape(b * dof, 1, h), kernel, padding="same"
)
.view(b, dof, h)
.transpose(-1, -2)
.contiguous()
)
return new_signal
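# Editorial usage sketch (not part of the release): the jit-scripted helper above
# applies the 5-tap smoothing kernel used by TensorStepPositionCliqueKernel along
# the horizon dimension of a [batch, horizon, dof] signal. Sizes are illustrative
# and the default TensorDeviceType (typically CUDA) is assumed.
def _example_filter_signal(batch: int = 2, horizon: int = 30, dof: int = 7):
    tensor_args = TensorDeviceType()
    kernel = tensor_args.to_device([[[0.06136, 0.24477, 0.38774, 0.24477, 0.06136]]])
    signal = torch.randn(
        (batch, horizon, dof), device=tensor_args.device, dtype=tensor_args.dtype
    )
    smoothed = filter_signal_jit(signal, kernel)
    assert smoothed.shape == signal.shape
    return smoothed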

View File

@@ -0,0 +1,563 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# import torch
from __future__ import annotations
# Standard Library
from abc import abstractmethod, abstractproperty
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence
# Third Party
import torch
# CuRobo
from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
from curobo.types.robot import CSpaceConfig, State
from curobo.types.tensor import (
T_BDOF,
T_DOF,
T_BHDOF_float,
T_BHValue_float,
T_BValue_bool,
T_BValue_float,
)
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.tensor_util import copy_tensor
@dataclass
class RolloutMetrics(Sequence):
cost: Optional[T_BValue_float] = None
constraint: Optional[T_BValue_float] = None
feasible: Optional[T_BValue_bool] = None
state: Optional[State] = None
def __getitem__(self, idx):
d_list = [self.cost, self.constraint, self.feasible, self.state]
idx_vals = list_idx_if_not_none(d_list, idx)
return RolloutMetrics(idx_vals[0], idx_vals[1], idx_vals[2], idx_vals[3])
def __len__(self):
if self.cost is not None:
return self.cost.shape[0]
else:
return -1
def clone(self, clone_state=False):
if clone_state:
raise NotImplementedError()
return RolloutMetrics(
cost=None if self.cost is None else self.cost.clone(),
constraint=None if self.constraint is None else self.constraint.clone(),
feasible=None if self.feasible is None else self.feasible.clone(),
state=None if self.state is None else self.state,
)
@dataclass
class Trajectory:
actions: T_BHDOF_float
costs: T_BHValue_float
state: Optional[State] = None
debug: Optional[dict] = None
@dataclass
class Goal(Sequence):
"""Goal data class used to update optimization target.
#NOTE:
We can parallelize Goal in two ways:
1. Solve for current_state, pose pair in same environment
2. Solve for current_state, pose pair in different environment
For case (1), we use batch_pose_idx to find the memory address of the
current_state, pose pair while keeping batch_world_idx = [0]
For case (2), we add a batch_world_idx[0,1,2..].
"""
name: str = "goal"
goal_state: Optional[State] = None
goal_pose: Pose = Pose()
links_goal_pose: Optional[Dict[str, Pose]] = None
current_state: Optional[State] = None
retract_state: Optional[T_DOF] = None
batch: int = -1 # NOTE: add another variable for size of index tensors?
# this should also contain a batch index tensor:
batch_pose_idx: Optional[torch.Tensor] = None # shape: [batch]
batch_goal_state_idx: Optional[torch.Tensor] = None
batch_retract_state_idx: Optional[torch.Tensor] = None
batch_current_state_idx: Optional[torch.Tensor] = None # shape: [batch]
batch_enable_idx: Optional[torch.Tensor] = None # shape: [batch, n]
batch_world_idx: Optional[torch.Tensor] = None # shape: [batch, n]
update_batch_idx_buffers: bool = True
n_goalset: int = 1 # NOTE: This currently does not get updated if goal_pose is updated later.
def __getitem__(self, idx):
d_list = [
self.goal_state,
self.goal_pose,
self.current_state,
self.retract_state,
self.batch_pose_idx,
self.batch_goal_state_idx,
self.batch_retract_state_idx,
self.batch_current_state_idx,
self.batch_enable_idx,
self.batch_world_idx,
]
idx_vals = list_idx_if_not_none(d_list, idx)
return Goal(
name=self.name,
batch=self.batch,
n_goalset=self.n_goalset,
goal_state=idx_vals[0],
goal_pose=idx_vals[1],
current_state=idx_vals[2],
retract_state=idx_vals[3],
batch_pose_idx=idx_vals[4],
batch_goal_state_idx=idx_vals[5],
batch_retract_state_idx=idx_vals[6],
batch_current_state_idx=idx_vals[7],
batch_enable_idx=idx_vals[8],
batch_world_idx=idx_vals[9],
)
def __len__(self):
return self.batch
def __post_init__(self):
self._update_batch_size()
if self.goal_pose.position is not None:
if self.batch_pose_idx is None:
self.batch_pose_idx = torch.arange(
0, self.batch, 1, device=self.goal_pose.position.device, dtype=torch.int32
).unsqueeze(-1)
self.n_goalset = self.goal_pose.n_goalset
if self.current_state is not None:
if self.batch_current_state_idx is None:
self.batch_current_state_idx = torch.arange(
0,
self.current_state.position.shape[0],
1,
device=self.current_state.position.device,
dtype=torch.int32,
).unsqueeze(-1)
if self.retract_state is not None:
if self.batch_retract_state_idx is None:
self.batch_retract_state_idx = torch.arange(
0,
self.retract_state.shape[0],
1,
device=self.retract_state.device,
dtype=torch.int32,
).unsqueeze(-1)
def _update_batch_size(self):
if self.goal_pose.position is not None:
self.batch = self.goal_pose.batch
elif self.goal_state is not None:
self.batch = self.goal_state.position.shape[0]
elif self.current_state is not None:
self.batch = self.current_state.position.shape[0]
def repeat_seeds(self, num_seeds: int):
        # across seeds the data is the same, so only the batch index tensors are repeated
        # TODO: expand batch_idx instead of repeating it?
goal_pose = goal_state = current_state = links_goal_pose = retract_state = None
batch_enable_idx = batch_pose_idx = batch_world_idx = batch_current_state_idx = None
batch_retract_state_idx = batch_goal_state_idx = None
if self.links_goal_pose is not None:
links_goal_pose = self.links_goal_pose
if self.goal_pose is not None:
goal_pose = self.goal_pose
# goal_pose = self.goal_pose.repeat_seeds(num_seeds)
if self.goal_state is not None:
goal_state = self.goal_state # .repeat_seeds(num_seeds)
if self.current_state is not None:
current_state = self.current_state # .repeat_seeds(num_seeds)
if self.retract_state is not None:
retract_state = self.retract_state
# repeat seeds for indexing:
if self.batch_pose_idx is not None:
batch_pose_idx = self._tensor_repeat_seeds(self.batch_pose_idx, num_seeds)
if self.batch_goal_state_idx is not None:
batch_goal_state_idx = self._tensor_repeat_seeds(self.batch_goal_state_idx, num_seeds)
if self.batch_retract_state_idx is not None:
batch_retract_state_idx = self._tensor_repeat_seeds(
self.batch_retract_state_idx, num_seeds
)
if self.batch_enable_idx is not None:
batch_enable_idx = self._tensor_repeat_seeds(self.batch_enable_idx, num_seeds)
if self.batch_world_idx is not None:
batch_world_idx = self._tensor_repeat_seeds(self.batch_world_idx, num_seeds)
if self.batch_current_state_idx is not None:
batch_current_state_idx = self._tensor_repeat_seeds(
self.batch_current_state_idx, num_seeds
)
return Goal(
goal_state=goal_state,
goal_pose=goal_pose,
current_state=current_state,
retract_state=retract_state,
batch_pose_idx=batch_pose_idx,
batch_world_idx=batch_world_idx,
batch_enable_idx=batch_enable_idx,
batch_current_state_idx=batch_current_state_idx,
batch_retract_state_idx=batch_retract_state_idx,
batch_goal_state_idx=batch_goal_state_idx,
links_goal_pose=links_goal_pose,
)
def _tensor_repeat_seeds(self, tensor, num_seeds):
return tensor_repeat_seeds(tensor, num_seeds)
def apply_kernel(self, kernel_mat):
# For each seed in optimization, we use kernel_mat to transform to many parallel goals
        # This could be reduced to transforming only the index tensors and updating self.batch from kernel_mat's shape
# TODO: add other elements
goal_pose = goal_state = current_state = links_goal_pose = None
batch_enable_idx = batch_pose_idx = batch_world_idx = batch_current_state_idx = None
batch_retract_state_idx = batch_goal_state_idx = None
if self.links_goal_pose is not None:
links_goal_pose = self.links_goal_pose
if self.goal_pose is not None:
goal_pose = self.goal_pose # .apply_kernel(kernel_mat)
if self.goal_state is not None:
goal_state = self.goal_state # .apply_kernel(kernel_mat)
if self.current_state is not None:
current_state = self.current_state # .apply_kernel(kernel_mat)
if self.batch_enable_idx is not None:
batch_enable_idx = kernel_mat @ self.batch_enable_idx
if self.batch_retract_state_idx is not None:
batch_retract_state_idx = (
kernel_mat @ self.batch_retract_state_idx.to(dtype=torch.float32)
).to(dtype=torch.int32)
if self.batch_goal_state_idx is not None:
batch_goal_state_idx = (
kernel_mat @ self.batch_goal_state_idx.to(dtype=torch.float32)
).to(dtype=torch.int32)
if self.batch_current_state_idx is not None:
batch_current_state_idx = (
kernel_mat @ self.batch_current_state_idx.to(dtype=torch.float32)
).to(dtype=torch.int32)
if self.batch_pose_idx is not None:
batch_pose_idx = (kernel_mat @ self.batch_pose_idx.to(dtype=torch.float32)).to(
dtype=torch.int32
)
if self.batch_world_idx is not None:
batch_world_idx = (kernel_mat @ self.batch_world_idx.to(dtype=torch.float32)).to(
dtype=torch.int32
)
return Goal(
goal_state=goal_state,
goal_pose=goal_pose,
current_state=current_state,
batch_pose_idx=batch_pose_idx,
batch_enable_idx=batch_enable_idx,
batch_world_idx=batch_world_idx,
batch_current_state_idx=batch_current_state_idx,
batch_goal_state_idx=batch_goal_state_idx,
batch_retract_state_idx=batch_retract_state_idx,
links_goal_pose=links_goal_pose,
)
def to(self, tensor_args: TensorDeviceType):
if self.goal_pose is not None:
self.goal_pose = self.goal_pose.to(tensor_args)
if self.goal_state is not None:
self.goal_state = self.goal_state.to(**vars(tensor_args))
if self.current_state is not None:
self.current_state = self.current_state.to(**vars(tensor_args))
return self
def copy_(self, goal: Goal, update_idx_buffers: bool = True):
"""Copy data from another goal object.
Args:
goal (Goal): _description_
Raises:
NotImplementedError: _description_
NotImplementedError: _description_
Returns:
_type_: _description_
"""
self.goal_pose = self._copy_buffer(self.goal_pose, goal.goal_pose)
self.goal_state = self._copy_buffer(self.goal_state, goal.goal_state)
self.retract_state = self._copy_tensor(self.retract_state, goal.retract_state)
self.current_state = self._copy_buffer(self.current_state, goal.current_state)
if goal.links_goal_pose is not None:
if self.links_goal_pose is None:
self.links_goal_pose = goal.links_goal_pose
else:
for k in goal.links_goal_pose.keys():
self.links_goal_pose[k] = self._copy_buffer(
self.links_goal_pose[k], goal.links_goal_pose[k]
)
self._update_batch_size()
# copy pose indices as well?
if goal.update_batch_idx_buffers and update_idx_buffers:
self.batch_pose_idx = self._copy_tensor(self.batch_pose_idx, goal.batch_pose_idx)
self.batch_enable_idx = self._copy_tensor(self.batch_enable_idx, goal.batch_enable_idx)
self.batch_world_idx = self._copy_tensor(self.batch_world_idx, goal.batch_world_idx)
self.batch_current_state_idx = self._copy_tensor(
self.batch_current_state_idx, goal.batch_current_state_idx
)
self.batch_retract_state_idx = self._copy_tensor(
self.batch_retract_state_idx, goal.batch_retract_state_idx
)
self.batch_goal_state_idx = self._copy_tensor(
self.batch_goal_state_idx, goal.batch_goal_state_idx
)
def _copy_buffer(self, ref_buffer, buffer):
if buffer is not None:
if ref_buffer is not None:
ref_buffer = ref_buffer.copy_(buffer)
else:
ref_buffer = buffer.clone()
return ref_buffer
def _copy_tensor(self, ref_buffer, buffer):
if buffer is not None:
if ref_buffer is not None:
if not copy_tensor(buffer, ref_buffer):
ref_buffer = buffer.clone()
else:
ref_buffer = buffer.clone()
return ref_buffer
def get_batch_goal_state(self):
return self.goal_state[self.batch_pose_idx[:, 0]]
def create_index_buffers(
self,
batch_size: int,
batch_env: bool,
batch_retract: bool,
num_seeds: int,
tensor_args: TensorDeviceType,
):
new_goal = Goal.create_idx(batch_size, batch_env, batch_retract, num_seeds, tensor_args)
new_goal.copy_(self, update_idx_buffers=False)
return new_goal
@classmethod
def create_idx(
cls,
pose_batch_size: int,
batch_env: bool,
batch_retract: bool,
num_seeds: int,
tensor_args: TensorDeviceType,
):
batch_pose_idx = torch.arange(
0, pose_batch_size, 1, device=tensor_args.device, dtype=torch.int32
).unsqueeze(-1)
if batch_env:
batch_world_idx = batch_pose_idx.clone()
else:
batch_world_idx = 0 * batch_pose_idx
if batch_retract:
batch_retract_state_idx = batch_pose_idx.clone()
else:
batch_retract_state_idx = 0 * batch_pose_idx.clone()
        batch_current_state_idx = batch_pose_idx.clone()
        batch_goal_state_idx = batch_pose_idx.clone()
        g = Goal(
            batch_pose_idx=batch_pose_idx,
            batch_retract_state_idx=batch_retract_state_idx,
            batch_world_idx=batch_world_idx,
            batch_current_state_idx=batch_current_state_idx,
batch_goal_state_idx=batch_goal_state_idx,
)
g_seeds = g.repeat_seeds(num_seeds)
return g_seeds
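# Editorial usage sketch (not part of the release): create_idx builds the batch
# index buffers used to parallelize goals across seeds, here for case (1) from the
# class docstring (a single environment, so batch_world_idx stays 0). Sizes are
# illustrative and the default TensorDeviceType (typically CUDA) is assumed.
def _example_goal_index_buffers():
    tensor_args = TensorDeviceType()
    goal = Goal.create_idx(
        pose_batch_size=4,
        batch_env=False,
        batch_retract=False,
        num_seeds=2,
        tensor_args=tensor_args,
    )
    # each pose index is repeated per seed: [0, 0, 1, 1, 2, 2, 3, 3]
    assert goal.batch_pose_idx.shape == (4 * 2, 1)
    return goal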
@dataclass
class RolloutConfig:
tensor_args: TensorDeviceType
class RolloutBase:
def __init__(self, config: Optional[RolloutConfig] = None):
self.start_state = None
self.batch_size = 1
self._metrics_cuda_graph_init = False
self.cu_metrics_graph = None
self._rollout_constraint_cuda_graph_init = False
self.cu_rollout_constraint_graph = None
if config is not None:
self.tensor_args = config.tensor_args
def _init_after_config_load(self):
self.act_sample_gen = HaltonGenerator(
self.d_action,
self.tensor_args,
up_bounds=self.action_bound_highs,
low_bounds=self.action_bound_lows,
seed=1312,
)
@abstractmethod
def cost_fn(self, state: State):
return
@abstractmethod
def constraint_fn(
self, state: State, out_metrics: Optional[RolloutMetrics] = None
) -> RolloutMetrics:
return
@abstractmethod
def convergence_fn(
self, state: State, out_metrics: Optional[RolloutMetrics] = None
) -> RolloutMetrics:
return
def get_metrics(self, state: State):
out_metrics = self.constraint_fn(state)
out_metrics = self.convergence_fn(state, out_metrics)
return out_metrics
def get_metrics_cuda_graph(self, state: State):
        return self.get_metrics(state)
def rollout_fn(self, act):
pass
def current_cost(self, current_state):
pass
@abstractmethod
def update_params(self, goal: Goal):
return
def __call__(self, act: T_BHDOF_float) -> Trajectory:
return self.rollout_fn(act)
@abstractproperty
def action_bounds(self):
return self.tensor_args.to_device(
torch.stack([self.action_bound_lows, self.action_bound_highs])
)
@abstractmethod
def filter_robot_state(self, current_state: State) -> State:
return current_state
@abstractmethod
def get_robot_command(
self, current_state, act_seq, shift_steps: int = 1, state_idx: Optional[torch.Tensor] = None
):
return act_seq
def reset_seed(self):
self.act_sample_gen.reset()
def reset(self):
return True
@abstractproperty
def d_action(self) -> int:
raise NotImplementedError
@abstractproperty
def action_bound_lows(self):
return 1
@abstractproperty
def action_bound_highs(self):
return 1
@abstractproperty
def dt(self):
return 0.1
@property
def horizon(self) -> int:
raise NotImplementedError
def update_start_state(self, start_state: torch.Tensor):
if self.start_state is None:
self.start_state = start_state
copy_tensor(start_state, self.start_state)
@abstractmethod
def get_init_action_seq(self):
raise NotImplementedError
@property
def state_bounds(self) -> Dict[str, List[float]]:
pass
# sample random actions
# @abstractmethod
def sample_random_actions(self, n: int = 0):
act_rand = self.act_sample_gen.get_samples(n, bounded=True)
return act_rand
# how to map act_seq to state?
# rollout for feasibility?
@abstractmethod
def rollout_constraint(self, act_seq: torch.Tensor) -> RolloutMetrics:
# get state by rolling out
# get feasibility:
pass
def reset_cuda_graph(self):
self._metrics_cuda_graph_init = False
if self.cu_metrics_graph is not None:
self.cu_metrics_graph.reset()
self._rollout_constraint_cuda_graph_init = False
if self.cu_rollout_constraint_graph is not None:
self.cu_rollout_constraint_graph.reset()
@abstractmethod
def get_action_from_state(self, state: State):
pass
@abstractmethod
def get_state_from_action(
self, start_state: State, act_seq: torch.Tensor, state_idx: Optional[torch.Tensor] = None
):
pass
@abstractproperty
def cspace_config(self) -> CSpaceConfig:
pass
def get_full_dof_from_solution(self, q_js: JointState) -> JointState:
return q_js
@torch.jit.script
def tensor_repeat_seeds(tensor, num_seeds: int):
a = (
tensor.view(tensor.shape[0], 1, tensor.shape[-1])
.repeat(1, num_seeds, 1)
.view(tensor.shape[0] * num_seeds, tensor.shape[-1])
)
return a
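# Editorial usage sketch (not part of the release): tensor_repeat_seeds interleaves
# each row of an index tensor num_seeds times, which is how Goal.repeat_seeds
# expands its batch index buffers across optimization seeds.
def _example_tensor_repeat_seeds():
    idx = torch.arange(3, dtype=torch.int32).unsqueeze(-1)  # [[0], [1], [2]]
    repeated = tensor_repeat_seeds(idx, 2)
    # repeated -> [[0], [0], [1], [1], [2], [2]]
    assert repeated.shape == (6, 1)
    return repeated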