constrained planning, robot segmentation
@@ -39,7 +39,7 @@ from curobo.rollout.rollout_base import Goal, RolloutBase, RolloutConfig, Rollou
from curobo.types.base import TensorDeviceType
from curobo.types.robot import CSpaceConfig, RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_info, log_warn
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.tensor_util import cat_sum


@@ -366,6 +366,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
)
cost_list.append(coll_cost)
if return_list:

return cost_list
cost = cat_sum(cost_list)
return cost
@@ -424,6 +425,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
out_metrics = self.constraint_fn(state)
out_metrics.state = state
out_metrics = self.convergence_fn(state, out_metrics)
out_metrics.cost = self.cost_fn(state)
return out_metrics

def get_metrics_cuda_graph(self, state: JointState):
@@ -451,6 +453,8 @@ class ArmBase(RolloutBase, ArmBaseConfig):
with torch.cuda.graph(self.cu_metrics_graph, stream=s):
self._cu_out_metrics = self.get_metrics(self._cu_metrics_state_in)
self._metrics_cuda_graph_init = True
if self._cu_metrics_state_in.position.shape != state.position.shape:
log_error("cuda graph changed")
self._cu_metrics_state_in.copy_(state)
self.cu_metrics_graph.replay()
out_metrics = self._cu_out_metrics
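# Editor's note: the hunk above captures get_metrics() into a CUDA graph once, then replays
# it while copying fresh joint states into the captured input buffer, erroring out if the
# shape changes. A minimal standalone sketch of that capture/replay pattern (requires a CUDA
# device; the squared-sum "metric" and all names are placeholders, not curobo's API):
import torch

def make_graphed_metric(static_input: torch.Tensor):
    # warm up on a side stream, then capture a single call into a CUDA graph
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out = (static_input ** 2).sum(dim=-1)
    torch.cuda.current_stream().wait_stream(s)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=s):
        static_out = (static_input ** 2).sum(dim=-1)

    def run(x: torch.Tensor) -> torch.Tensor:
        # a replay only works for the captured shape; a mismatch must be treated as an
        # error, which is what the log_error("cuda graph changed") guard above does
        if x.shape != static_input.shape:
            raise RuntimeError("cuda graph changed")
        static_input.copy_(x)  # write new data into the captured buffer
        graph.replay()
        return static_out

    return run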
@@ -462,17 +466,6 @@ class ArmBase(RolloutBase, ArmBaseConfig):
):
if out_metrics is None:
out_metrics = RolloutMetrics()
if (
self.convergence_cfg.null_space_cfg is not None
and self.null_convergence.enabled
and self._goal_buffer.batch_retract_state_idx is not None
):
out_metrics.cost = self.null_convergence.forward_target_idx(
self._goal_buffer.retract_state,
state.state_seq.position,
self._goal_buffer.batch_retract_state_idx,
)

return out_metrics

def _get_augmented_state(self, state: JointState) -> KinematicModelState:
@@ -688,9 +681,11 @@ class ArmBase(RolloutBase, ArmBaseConfig):
act_seq = self.dynamics_model.init_action_mean.unsqueeze(0).repeat(self.batch_size, 1, 1)
return act_seq

def reset_cuda_graph(self):
def reset_shape(self):
self._goal_idx_update = True
super().reset_shape()

def reset_cuda_graph(self):
super().reset_cuda_graph()

def get_action_from_state(self, state: JointState):

@@ -20,16 +20,16 @@ import torch.autograd.profiler as profiler
from curobo.geom.sdf.world import WorldCollision
from curobo.rollout.cost.cost_base import CostConfig
from curobo.rollout.cost.dist_cost import DistCost, DistCostConfig
from curobo.rollout.cost.pose_cost import PoseCost, PoseCostConfig
from curobo.rollout.cost.pose_cost import PoseCost, PoseCostConfig, PoseCostMetric
from curobo.rollout.cost.straight_line_cost import StraightLineCost
from curobo.rollout.cost.zero_cost import ZeroCost
from curobo.rollout.dynamics_model.kinematic_model import KinematicModelState
from curobo.rollout.rollout_base import Goal, RolloutMetrics
from curobo.types.base import TensorDeviceType
from curobo.types.robot import RobotConfig
from curobo.types.tensor import T_BValue_float
from curobo.types.tensor import T_BValue_float, T_BValue_int
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_info
from curobo.util.logger import log_error, log_info
from curobo.util.tensor_util import cat_max, cat_sum

# Local Folder
@@ -42,6 +42,8 @@ class ArmReacherMetrics(RolloutMetrics):
position_error: Optional[T_BValue_float] = None
rotation_error: Optional[T_BValue_float] = None
pose_error: Optional[T_BValue_float] = None
goalset_index: Optional[T_BValue_int] = None
null_space_error: Optional[T_BValue_float] = None

def __getitem__(self, idx):
d_list = [
@@ -53,6 +55,8 @@ class ArmReacherMetrics(RolloutMetrics):
self.position_error,
self.rotation_error,
self.pose_error,
self.goalset_index,
self.null_space_error,
]
idx_vals = list_idx_if_not_none(d_list, idx)
return ArmReacherMetrics(*idx_vals)
@@ -65,10 +69,14 @@ class ArmReacherMetrics(RolloutMetrics):
constraint=None if self.constraint is None else self.constraint.clone(),
feasible=None if self.feasible is None else self.feasible.clone(),
state=None if self.state is None else self.state,
cspace_error=None if self.cspace_error is None else self.cspace_error,
position_error=None if self.position_error is None else self.position_error,
rotation_error=None if self.rotation_error is None else self.rotation_error,
pose_error=None if self.pose_error is None else self.pose_error,
cspace_error=None if self.cspace_error is None else self.cspace_error.clone(),
position_error=None if self.position_error is None else self.position_error.clone(),
rotation_error=None if self.rotation_error is None else self.rotation_error.clone(),
pose_error=None if self.pose_error is None else self.pose_error.clone(),
goalset_index=None if self.goalset_index is None else self.goalset_index.clone(),
null_space_error=(
None if self.null_space_error is None else self.null_space_error.clone()
),
)


@@ -254,6 +262,7 @@ class ArmReacher(ArmBase, ArmReacherConfig):
goal_cost = self.goal_cost.forward(
ee_pos_batch, ee_quat_batch, self._goal_buffer
)

cost_list.append(goal_cost)
with profiler.record_function("cost/link_poses"):
if self._goal_buffer.links_goal_pose is not None and self.cost_cfg.pose_cfg is not None:
@@ -296,36 +305,21 @@ class ArmReacher(ArmBase, ArmReacherConfig):
g_dist,
)

# cost += z_acc
cost_list.append(z_acc)
# print(self.cost_cfg.zero_jerk_cfg)
if (
self.cost_cfg.zero_jerk_cfg is not None
and self.zero_jerk_cost.enabled
# and g_dist is not None
):
# jerk = self.dynamics_model._aux_matrix @ state_batch.acceleration
if self.cost_cfg.zero_jerk_cfg is not None and self.zero_jerk_cost.enabled:
z_jerk = self.zero_jerk_cost.forward(
state_batch.jerk,
g_dist,
)
cost_list.append(z_jerk)
# cost += z_jerk

if (
self.cost_cfg.zero_vel_cfg is not None
and self.zero_vel_cost.enabled
# and g_dist is not None
):
if self.cost_cfg.zero_vel_cfg is not None and self.zero_vel_cost.enabled:
z_vel = self.zero_vel_cost.forward(
state_batch.velocity,
g_dist,
)
# cost += z_vel
# print(z_vel.shape)
cost_list.append(z_vel)
cost = cat_sum(cost_list)
# print(cost[:].T)
return cost

def convergence_fn(
@@ -350,6 +344,7 @@ class ArmReacher(ArmBase, ArmReacherConfig):
) = self.pose_convergence.forward_out_distance(
state.ee_pos_seq, state.ee_quat_seq, self._goal_buffer
)
out_metrics.goalset_index = self.pose_convergence.goalset_index_buffer # .clone()
if (
self._goal_buffer.links_goal_pose is not None
and self.convergence_cfg.pose_cfg is not None
@@ -389,6 +384,17 @@ class ArmReacher(ArmBase, ArmReacherConfig):
True,
)

if (
self.convergence_cfg.null_space_cfg is not None
and self.null_convergence.enabled
and self._goal_buffer.batch_retract_state_idx is not None
):
out_metrics.null_space_error = self.null_convergence.forward_target_idx(
self._goal_buffer.retract_state,
state.state_seq.position,
self._goal_buffer.batch_retract_state_idx,
)

return out_metrics

def update_params(
@@ -420,3 +426,43 @@ class ArmReacher(ArmBase, ArmReacherConfig):
else:
self.dist_cost.disable_cost()
self.cspace_convergence.disable_cost()

def get_pose_costs(self, include_link_pose: bool = False, include_convergence: bool = True):
pose_costs = [self.goal_cost]
if include_convergence:
pose_costs += [self.pose_convergence]
if include_link_pose:
log_error("Not implemented yet")
return pose_costs

def update_pose_cost_metric(
self,
metric: PoseCostMetric,
):
pose_costs = self.get_pose_costs()
if metric.hold_partial_pose:
if metric.hold_vec_weight is None:
log_error("hold_vec_weight is required")
[x.hold_partial_pose(metric.hold_vec_weight) for x in pose_costs]
if metric.release_partial_pose:
[x.release_partial_pose() for x in pose_costs]
if metric.reach_partial_pose:
if metric.reach_vec_weight is None:
log_error("reach_vec_weight is required")
[x.reach_partial_pose(metric.reach_vec_weight) for x in pose_costs]
if metric.reach_full_pose:
[x.reach_full_pose() for x in pose_costs]

pose_costs = self.get_pose_costs(include_convergence=False)
if metric.remove_offset_waypoint:
[x.remove_offset_waypoint() for x in pose_costs]

if metric.offset_position is not None or metric.offset_rotation is not None:
[
x.update_offset_waypoint(
offset_position=metric.offset_position,
offset_rotation=metric.offset_rotation,
offset_tstep_fraction=metric.offset_tstep_fraction,
)
for x in pose_costs
]

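# Editor's note: a minimal usage sketch of the new update_pose_cost_metric() method added
# above. Constructing an ArmReacher rollout is elided; `arm_reacher` is assumed to be an
# already-configured ArmReacher instance, and the weight tensor is kept on CPU here for
# brevity (in practice it would live on the rollout's device), so this is illustrative only.
import torch
from curobo.rollout.cost.pose_cost import PoseCostMetric

# hold translation along x/y and full orientation, leaving z free; the weight layout
# follows the [rx, ry, rz, x, y, z] convention used by the pose cost
hold_vec_weight = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 0.0])
metric = PoseCostMetric(hold_partial_pose=True, hold_vec_weight=hold_vec_weight)
arm_reacher.update_pose_cost_metric(metric)  # pushes the metric into goal_cost and pose_convergence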
@@ -257,13 +257,13 @@ class BoundCost(CostBase, BoundCostConfig):
return cost

def update_dt(self, dt: Union[float, torch.Tensor]):
# return super().update_dt(dt)
if self.cost_type == BoundCostType.BOUNDS_SMOOTH:
v_scale = dt / self._dt
a_scale = v_scale**2
j_scale = v_scale**3
self.smooth_weight[1] *= a_scale
self.smooth_weight[2] *= j_scale

return super().update_dt(dt)



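# Editor's note: the update_dt() change above rescales the smoothness weights when the
# trajectory timestep changes: velocity terms scale linearly with the dt ratio, acceleration
# terms quadratically, and jerk terms cubically. A small numeric sketch of that scaling
# (standalone, with made-up weights; not curobo's BoundCost API):
old_dt, new_dt = 0.02, 0.04
v_scale = new_dt / old_dt          # 2.0
a_scale = v_scale**2               # 4.0 -> multiplies the acceleration weight
j_scale = v_scale**3               # 8.0 -> multiplies the jerk weight
smooth_weight = [1.0, 0.1, 0.01]   # [velocity, acceleration, jerk] weights (illustrative)
smooth_weight[1] *= a_scale        # 0.4
smooth_weight[2] *= j_scale        # 0.08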
@@ -8,19 +8,23 @@
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
from __future__ import annotations

# Standard Library
import math
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

# Third Party
import torch
from torch.autograd import Function

# CuRobo
from curobo.curobolib.geom import get_pose_distance, get_pose_distance_backward
from curobo.curobolib.geom import PoseError, PoseErrorDistance
from curobo.rollout.rollout_base import Goal
from curobo.types.base import TensorDeviceType
from curobo.types.math import OrientationError, Pose
from curobo.util.logger import log_error

# Local Folder
from .cost_base import CostBase, CostConfig
@@ -37,7 +41,11 @@ class PoseErrorType(Enum):
class PoseCostConfig(CostConfig):
cost_type: PoseErrorType = PoseErrorType.BATCH_GOAL
use_metric: bool = False
project_distance: bool = True
run_vec_weight: Optional[List[float]] = None
use_projected_distance: bool = True
offset_waypoint: List[float] = None
offset_tstep_fraction: float = -1.0

def __post_init__(self):
if self.run_vec_weight is not None:
@@ -54,392 +62,85 @@ class PoseCostConfig(CostConfig):
self.vec_convergence = torch.zeros(
2, device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
if self.offset_waypoint is None:
self.offset_waypoint = [0, 0, 0, 0, 0, 0]
if self.run_weight is None:
self.run_weight = 1
self.offset_waypoint = self.tensor_args.to_device(self.offset_waypoint)
if isinstance(self.offset_tstep_fraction, float):
self.offset_tstep_fraction = self.tensor_args.to_device([self.offset_tstep_fraction])
return super().__post_init__()


@torch.jit.script
def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
grad_vec = grad_g_dist + (grad_out_distance * weight)
grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec
return grad
@dataclass
class PoseCostMetric:
hold_partial_pose: bool = False
release_partial_pose: bool = False
hold_vec_weight: Optional[torch.Tensor] = None
reach_partial_pose: bool = False
reach_full_pose: bool = False
reach_vec_weight: Optional[torch.Tensor] = None
offset_position: Optional[torch.Tensor] = None
offset_rotation: Optional[torch.Tensor] = None
offset_tstep_fraction: float = -1.0
remove_offset_waypoint: bool = False

def clone(self):

# full method:
@torch.jit.script
def backward_full_PoseError_jit(
grad_out_distance, grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
):
p_grad = (grad_g_dist + (grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
q_grad = (grad_r_err + (grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
# p_grad = ((grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
# q_grad = ((grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q

return p_grad, q_grad


class PoseErrorDistance(Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode=PoseErrorType.BATCH_GOAL.value,
num_goals=1,
use_metric=False,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)

(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
True,
use_metric,
)
ctx.save_for_backward(out_p_vec, out_r_vec, weight, out_p_grad, out_q_grad)
return out_distance, out_position_distance, out_rotation_distance, out_idx # .view(-1,1)

@staticmethod
def backward(ctx, grad_out_distance, grad_g_dist, grad_r_err, grad_out_idx):
(g_vec_p, g_vec_q, weight, out_grad_p, out_grad_q) = ctx.saved_tensors
pos_grad = None
quat_grad = None
batch_size = g_vec_p.shape[0] * g_vec_p.shape[1]
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
pos_grad, quat_grad = get_pose_distance_backward(
out_grad_p,
out_grad_q,
grad_out_distance.contiguous(),
grad_g_dist.contiguous(),
grad_r_err.contiguous(),
weight,
g_vec_p,
g_vec_q,
batch_size,
use_distance=True,
)
# pos_grad, quat_grad = backward_full_PoseError_jit(
# grad_out_distance,
# grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
# )
elif ctx.needs_input_grad[0]:
pos_grad = backward_PoseError_jit(grad_g_dist, grad_out_distance, p_w, g_vec_p)
# grad_vec = grad_g_dist + (grad_out_distance * weight[1])
# pos_grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec[..., 4:]
elif ctx.needs_input_grad[2]:
quat_grad = backward_PoseError_jit(grad_r_err, grad_out_distance, q_w, g_vec_q)
# grad_vec = grad_r_err + (grad_out_distance * weight[0])
# quat_grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec[..., :4]
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
return PoseCostMetric(
hold_partial_pose=self.hold_partial_pose,
release_partial_pose=self.release_partial_pose,
hold_vec_weight=None if self.hold_vec_weight is None else self.hold_vec_weight.clone(),
reach_partial_pose=self.reach_partial_pose,
reach_full_pose=self.reach_full_pose,
reach_vec_weight=(
None if self.reach_vec_weight is None else self.reach_vec_weight.clone()
),
offset_position=None if self.offset_position is None else self.offset_position.clone(),
offset_rotation=None if self.offset_rotation is None else self.offset_rotation.clone(),
offset_tstep_fraction=self.offset_tstep_fraction,
remove_offset_waypoint=self.remove_offset_waypoint,
)

@classmethod
def create_grasp_approach_metric(
cls,
offset_position: float = 0.5,
linear_axis: int = 2,
tstep_fraction: float = 0.6,
tensor_args: TensorDeviceType = TensorDeviceType(),
) -> PoseCostMetric:
"""Enables moving to a pregrasp and then locked orientation movement to final grasp.

class PoseLoss(Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode=PoseErrorType.BATCH_GOAL.value,
num_goals=1,
use_metric=False,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)
Since this is added as a cost, the trajectory will not reach the exact offset, instead it
will try to take a blended path to the final grasp without stopping at the offset.

(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
False,
use_metric,
)
ctx.save_for_backward(out_p_vec, out_r_vec)
return out_distance
Args:
offset_position: offset in meters.
linear_axis: specifies the x y or z axis.
tstep_fraction: specifies the timestep fraction to start activating this constraint.
tensor_args: cuda device.

@staticmethod
def backward(ctx, grad_out_distance): # , grad_g_dist, grad_r_err, grad_out_idx):
pos_grad = None
quat_grad = None
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p * grad_out_distance.unsqueeze(1)
quat_grad = g_vec_q * grad_out_distance.unsqueeze(1)
pos_grad = pos_grad.unsqueeze(-2)
quat_grad = quat_grad.unsqueeze(-2)
elif ctx.needs_input_grad[0]:
(g_vec_p, g_vec_q) = ctx.saved_tensors

pos_grad = g_vec_p * grad_out_distance.unsqueeze(1)
pos_grad = pos_grad.unsqueeze(-2)
elif ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors

quat_grad = g_vec_q * grad_out_distance.unsqueeze(1)
quat_grad = quat_grad.unsqueeze(-2)

return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
Returns:
cost metric.
"""
hold_vec_weight = tensor_args.to_device([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
hold_vec_weight[3 + linear_axis] = 0.0
offset_position_vec = tensor_args.to_device([0.0, 0.0, 0.0])
offset_position_vec[linear_axis] = offset_position
return cls(
hold_partial_pose=True,
hold_vec_weight=hold_vec_weight,
offset_position=offset_position_vec,
offset_tstep_fraction=tstep_fraction,
)


class PoseError(Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode=PoseErrorType.BATCH_GOAL.value,
num_goals=1,
use_metric=False,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)

(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
False,
use_metric,
)
ctx.save_for_backward(out_p_vec, out_r_vec)
return out_distance

@staticmethod
def backward(ctx, grad_out_distance): # , grad_g_dist, grad_r_err, grad_out_idx):
pos_grad = None
quat_grad = None
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors

pos_grad = g_vec_p
quat_grad = g_vec_q
elif ctx.needs_input_grad[0]:
(g_vec_p, g_vec_q) = ctx.saved_tensors

pos_grad = g_vec_p
elif ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors

quat_grad = g_vec_q
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
@classmethod
def reset_metric(cls) -> PoseCostMetric:
return PoseCostMetric(
remove_offset_waypoint=True,
reach_full_pose=True,
release_partial_pose=True,
)


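# Editor's note: the hunk above removes the old autograd Function-based pose error classes
# (now imported from curobo.curobolib.geom) and adds the PoseCostMetric dataclass. A short
# usage sketch of the grasp-approach constructor it introduces (standalone; the motion-gen
# side that consumes the metric is elided, and the numeric values are illustrative):
from curobo.rollout.cost.pose_cost import PoseCostMetric
from curobo.types.base import TensorDeviceType

tensor_args = TensorDeviceType()
# constrain the final approach so only motion along the chosen linear axis (index 2) is
# free, starting from a 5 cm offset, active from 60% of the trajectory onwards
metric = PoseCostMetric.create_grasp_approach_metric(
    offset_position=0.05,
    linear_axis=2,
    tstep_fraction=0.6,
    tensor_args=tensor_args,
)
# reverting to an unconstrained full-pose goal later:
reset = PoseCostMetric.reset_metric()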
@@ -449,13 +150,88 @@ class PoseCost(CostBase, PoseCostConfig):
CostBase.__init__(self)
self.rot_weight = self.vec_weight[0:3]
self.pos_weight = self.vec_weight[3:6]

self._vec_convergence = self.tensor_args.to_device(self.vec_convergence)
self._batch_size = 0
self._horizon = 0

def update_metric(self, metric: PoseCostMetric):
if metric.hold_partial_pose:
if metric.hold_vec_weight is None:
log_error("hold_vec_weight is required")
self.hold_partial_pose(metric.hold_vec_weight)
if metric.release_partial_pose:
self.release_partial_pose()
if metric.reach_partial_pose:
if metric.reach_vec_weight is None:
log_error("reach_vec_weight is required")
self.reach_partial_pose(metric.reach_vec_weight)
if metric.reach_full_pose:
self.reach_full_pose()

if metric.remove_offset_waypoint:
self.remove_offset_waypoint()

if metric.offset_position is not None or metric.offset_rotation is not None:
self.update_offset_waypoint(
offset_position=self.offset_position,
offset_rotation=self.offset_rotation,
offset_tstep_fraction=self.offset_tstep_fraction,
)

def hold_partial_pose(self, run_vec_weight: torch.Tensor):

self.run_vec_weight.copy_(run_vec_weight)

def release_partial_pose(self):
self.run_vec_weight[:] = 0.0

def reach_partial_pose(self, vec_weight: torch.Tensor):
self.vec_weight[:] = vec_weight

def reach_full_pose(self):
self.vec_weight[:] = 1.0

def update_offset_waypoint(
self,
offset_position: Optional[torch.Tensor] = None,
offset_rotation: Optional[torch.Tensor] = None,
offset_tstep_fraction: float = 0.75,
):
if offset_position is not None:
self.offset_waypoint[3:].copy_(offset_position)
if offset_rotation is not None:
self.offset_waypoint[:3].copy_(offset_rotation)
self.offset_tstep_fraction[:] = offset_tstep_fraction
if self._horizon <= 0:
print(self.weight)
log_error(
"Updating offset waypoint is only possible after initializing motion gen"
+ " run motion_gen.warmup() before adding offset_waypoint"
)
self.update_run_weight(run_tstep_fraction=offset_tstep_fraction)

def remove_offset_waypoint(self):
self.offset_tstep_fraction[:] = -1.0
self.update_run_weight()

def update_run_weight(
self, run_tstep_fraction: float = 0.0, run_weight: Optional[float] = None
):
if self._horizon == 1:
return

if run_weight is None:
run_weight = self.run_weight

active_steps = math.floor(self._horizon * run_tstep_fraction)
self._run_weight_vec[:, :active_steps] = 0
self._run_weight_vec[:, active_steps:-1] = run_weight

def update_batch_size(self, batch_size, horizon):
if batch_size != self._batch_size or horizon != self._horizon:
# print(self.weight)
# print(batch_size, horizon, self._batch_size, self._horizon)

# batch_size = b*h
self.out_distance = torch.zeros(
(batch_size, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
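# Editor's note: update_run_weight() above zeroes the running pose weight for the first
# fraction of the horizon and applies run_weight to the remaining non-terminal steps. A
# small standalone sketch of that indexing (illustrative values, not curobo's buffers):
import math
import torch

horizon = 30
run_tstep_fraction = 0.6            # e.g. the offset_tstep_fraction of a waypoint metric
run_weight = 1.0
run_weight_vec = torch.ones(1, horizon)
active_steps = math.floor(horizon * run_tstep_fraction)  # 18
run_weight_vec[:, :active_steps] = 0             # steps 0..17 ignore the running pose term
run_weight_vec[:, active_steps:-1] = run_weight  # steps 18..28 use run_weight
# the final step (index 29) keeps its terminal weight of 1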
@@ -493,12 +269,16 @@ class PoseCost(CostBase, PoseCostConfig):
self._run_weight_vec = torch.ones(
(1, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
if self.terminal and self.run_weight is not None:
self._run_weight_vec[:, :-1] *= self.run_weight
if self.terminal and self.run_weight is not None and horizon > 1:
self._run_weight_vec[:, :-1] = self.run_weight

self._batch_size = batch_size
self._horizon = horizon

@property
def goalset_index_buffer(self):
return self.out_idx

def _forward_goal_distribution(self, ee_pos_batch, ee_rot_batch, ee_goal_pos, ee_goal_rot):
ee_goal_pos = ee_goal_pos.unsqueeze(1)
ee_goal_pos = ee_goal_pos.unsqueeze(1)
@@ -563,13 +343,13 @@ class PoseCost(CostBase, PoseCostConfig):
def _update_cost_type(self, ee_goal_pos, ee_pos_batch, num_goals):
d_g = len(ee_goal_pos.shape)
b_sze = ee_goal_pos.shape[0]
if d_g == 2 and b_sze == 1: # 1, 3
if d_g == 2 and b_sze == 1: # [1, 3]
self.cost_type = PoseErrorType.SINGLE_GOAL
elif d_g == 2 and b_sze == ee_pos_batch.shape[0]: # b, 3
elif d_g == 2 and b_sze > 1: # [b, 3]
self.cost_type = PoseErrorType.BATCH_GOAL
elif d_g == 3:
elif d_g == 3 and b_sze == 1: # [1, goalset, 3]
self.cost_type = PoseErrorType.GOALSET
elif len(ee_goal_pos.shape) == 4 and b_sze == ee_pos_bath.shape[0]:
elif d_g == 3 and b_sze > 1: # [b, goalset,3]
self.cost_type = PoseErrorType.BATCH_GOALSET

def forward_out_distance(
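# Editor's note: the rewritten _update_cost_type() above selects the pose error mode purely
# from the goal position tensor's shape. A standalone summary of that mapping (illustrative
# helper, not part of curobo):
def select_pose_error_mode(goal_position_shape: tuple) -> str:
    d_g = len(goal_position_shape)
    b = goal_position_shape[0]
    if d_g == 2 and b == 1:   # [1, 3] -> one goal shared by all rollouts
        return "SINGLE_GOAL"
    if d_g == 2 and b > 1:    # [b, 3] -> one goal per batch element
        return "BATCH_GOAL"
    if d_g == 3 and b == 1:   # [1, goalset, 3] -> one goal set shared by all rollouts
        return "GOALSET"
    if d_g == 3 and b > 1:    # [b, goalset, 3] -> a goal set per batch element
        return "BATCH_GOALSET"
    raise ValueError("unsupported goal shape")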
@@ -599,6 +379,8 @@ class PoseCost(CostBase, PoseCostConfig):
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
self.offset_waypoint,
self.offset_tstep_fraction,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
@@ -613,7 +395,9 @@ class PoseCost(CostBase, PoseCostConfig):
self.cost_type.value,
num_goals,
self.use_metric,
self.project_distance,
)
# print(self.out_idx.shape, self.out_idx[:,-1])
# print(goal.batch_pose_idx.shape)
cost = distance # .view(b, h)#.clone()
r_err = r_err # .view(b, h)
@@ -632,65 +416,46 @@ class PoseCost(CostBase, PoseCostConfig):
ee_goal_rot = goal_pose.quaternion
num_goals = goal_pose.n_goalset
self._update_cost_type(ee_goal_pos, ee_pos_batch, num_goals)

# print(self.cost_type)
b, h, _ = ee_pos_batch.shape
self.update_batch_size(b, h)
# return self.out_distance
# print(b,h, ee_goal_pos.shape)
if self.return_loss:
distance = PoseLoss.apply(
ee_pos_batch,
ee_goal_pos,
ee_rot_batch, # .view(-1, 4).contiguous(),
ee_goal_rot,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
else:
distance = PoseError.apply(
ee_pos_batch,
ee_goal_pos,
ee_rot_batch, # .view(-1, 4).contiguous(),
ee_goal_rot,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)

distance = PoseError.apply(
ee_pos_batch,
ee_goal_pos,
ee_rot_batch, # .view(-1, 4).contiguous(),
ee_goal_rot,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
self.offset_waypoint,
self.offset_tstep_fraction,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
self.project_distance,
self.return_loss,
)

cost = distance
# if link_name is None and cost.shape[0]==8:
# print(ee_pos_batch[...,-1].squeeze())
# print(cost.shape)
return cost

def forward_pose(
@@ -708,56 +473,34 @@ class PoseCost(CostBase, PoseCostConfig):
b = query_pose.position.shape[0]
h = query_pose.position.shape[1]
num_goals = 1
if self.return_loss:
distance = PoseLoss.apply(
query_pose.position.unsqueeze(1),
ee_goal_pos,
query_pose.quaternion.unsqueeze(1),
ee_goal_quat,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
else:
distance = PoseError.apply(
query_pose.position.unsqueeze(1),
ee_goal_pos,
query_pose.quaternion.unsqueeze(1),
ee_goal_quat,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)

distance = PoseError.apply(
query_pose.position.unsqueeze(1),
ee_goal_pos,
query_pose.quaternion.unsqueeze(1),
ee_goal_quat,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
self.offset_waypoint,
self.offset_tstep_fraction,
batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
self.project_distance,
self.return_loss,
)
return distance

@@ -68,9 +68,14 @@ class StopCost(CostBase, StopCostConfig):
self.max_vel = (sum_matrix @ delta_vel).unsqueeze(-1)

def forward(self, vels):
vel_abs = torch.abs(vels)
vel_abs = torch.nn.functional.relu(vel_abs - self.max_vel)

cost = self.weight * (torch.sum(vel_abs**2, dim=-1))

cost = velocity_cost(vels, self.weight, self.max_vel)
return cost


@torch.jit.script
def velocity_cost(vels, weight, max_vel):
vel_abs = torch.abs(vels)
vel_abs = torch.nn.functional.relu(vel_abs - max_vel[: vels.shape[1]])
cost = weight * (torch.sum(vel_abs**2, dim=-1))

return cost
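# Editor's note: the refactor above moves StopCost.forward's math into a TorchScript helper.
# The cost is a hinge penalty: only the amount of speed above max_vel contributes, squared
# and summed over the joint dimension. A standalone sketch with illustrative shapes
# (batch x horizon x dof); names here are placeholders, not curobo's:
import torch

@torch.jit.script
def velocity_cost_sketch(vels: torch.Tensor, weight: torch.Tensor, max_vel: torch.Tensor) -> torch.Tensor:
    vel_abs = torch.abs(vels)
    vel_abs = torch.nn.functional.relu(vel_abs - max_vel[: vels.shape[1]])
    return weight * torch.sum(vel_abs**2, dim=-1)

vels = torch.zeros(2, 30, 7)      # batch=2, horizon=30, dof=7
weight = torch.tensor(100.0)
max_vel = torch.ones(30, 1)       # per-timestep limit, matching the max_vel[: horizon] slice
cost = velocity_cost_sketch(vels, weight, max_vel)  # shape: (2, 30)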
@@ -23,7 +23,7 @@ import torch
# CuRobo
from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
from curobo.types.robot import CSpaceConfig, State, JointState
from curobo.types.robot import CSpaceConfig, State
from curobo.types.tensor import (
T_BDOF,
T_DOF,
@@ -33,6 +33,7 @@ from curobo.types.tensor import (
T_BValue_float,
)
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_info
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.tensor_util import copy_tensor

@@ -235,6 +236,7 @@ class Goal(Sequence):
batch_retract_state_idx=self.batch_retract_state_idx,
batch_goal_state_idx=self.batch_goal_state_idx,
links_goal_pose=self.links_goal_pose,
n_goalset=self.n_goalset,
)

def _tensor_repeat_seeds(self, tensor, num_seeds):
@@ -353,7 +355,7 @@ class Goal(Sequence):

def _copy_tensor(self, ref_buffer, buffer):
if buffer is not None:
if ref_buffer is not None:
if ref_buffer is not None and buffer.shape == ref_buffer.shape:
if not copy_tensor(buffer, ref_buffer):
ref_buffer = buffer.clone()
else:
@@ -553,6 +555,10 @@ class RolloutBase:
self._rollout_constraint_cuda_graph_init = False
if self.cu_rollout_constraint_graph is not None:
self.cu_rollout_constraint_graph.reset()
self.reset_shape()

def reset_shape(self):
pass

@abstractmethod
def get_action_from_state(self, state: State):
