constrained planning, robot segmentation

This commit is contained in:
Balakumar Sundaralingam
2024-02-22 21:45:47 -08:00
parent 88eac64edc
commit bafdf80c05
102 changed files with 12440 additions and 8112 deletions

View File

@@ -257,13 +257,13 @@ class BoundCost(CostBase, BoundCostConfig):
return cost
def update_dt(self, dt: Union[float, torch.Tensor]):
# return super().update_dt(dt)
if self.cost_type == BoundCostType.BOUNDS_SMOOTH:
v_scale = dt / self._dt
a_scale = v_scale**2
j_scale = v_scale**3
self.smooth_weight[1] *= a_scale
self.smooth_weight[2] *= j_scale
return super().update_dt(dt)

View File

@@ -8,19 +8,23 @@
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
from __future__ import annotations
# Standard Library
import math
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
# Third Party
import torch
from torch.autograd import Function
# CuRobo
from curobo.curobolib.geom import get_pose_distance, get_pose_distance_backward
from curobo.curobolib.geom import PoseError, PoseErrorDistance
from curobo.rollout.rollout_base import Goal
from curobo.types.base import TensorDeviceType
from curobo.types.math import OrientationError, Pose
from curobo.util.logger import log_error
# Local Folder
from .cost_base import CostBase, CostConfig
@@ -37,7 +41,11 @@ class PoseErrorType(Enum):
class PoseCostConfig(CostConfig):
cost_type: PoseErrorType = PoseErrorType.BATCH_GOAL
use_metric: bool = False
project_distance: bool = True
run_vec_weight: Optional[List[float]] = None
use_projected_distance: bool = True
offset_waypoint: List[float] = None
offset_tstep_fraction: float = -1.0
def __post_init__(self):
if self.run_vec_weight is not None:
@@ -54,392 +62,85 @@ class PoseCostConfig(CostConfig):
self.vec_convergence = torch.zeros(
2, device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
if self.offset_waypoint is None:
self.offset_waypoint = [0, 0, 0, 0, 0, 0]
if self.run_weight is None:
self.run_weight = 1
self.offset_waypoint = self.tensor_args.to_device(self.offset_waypoint)
if isinstance(self.offset_tstep_fraction, float):
self.offset_tstep_fraction = self.tensor_args.to_device([self.offset_tstep_fraction])
return super().__post_init__()
@torch.jit.script
def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
grad_vec = grad_g_dist + (grad_out_distance * weight)
grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec
return grad
@dataclass
class PoseCostMetric:
hold_partial_pose: bool = False
release_partial_pose: bool = False
hold_vec_weight: Optional[torch.Tensor] = None
reach_partial_pose: bool = False
reach_full_pose: bool = False
reach_vec_weight: Optional[torch.Tensor] = None
offset_position: Optional[torch.Tensor] = None
offset_rotation: Optional[torch.Tensor] = None
offset_tstep_fraction: float = -1.0
remove_offset_waypoint: bool = False
def clone(self):
# full method:
@torch.jit.script
def backward_full_PoseError_jit(
grad_out_distance, grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
):
p_grad = (grad_g_dist + (grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
q_grad = (grad_r_err + (grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
# p_grad = ((grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
# q_grad = ((grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
return p_grad, q_grad
class PoseErrorDistance(Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode=PoseErrorType.BATCH_GOAL.value,
num_goals=1,
use_metric=False,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)
(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
True,
use_metric,
)
ctx.save_for_backward(out_p_vec, out_r_vec, weight, out_p_grad, out_q_grad)
return out_distance, out_position_distance, out_rotation_distance, out_idx # .view(-1,1)
@staticmethod
def backward(ctx, grad_out_distance, grad_g_dist, grad_r_err, grad_out_idx):
(g_vec_p, g_vec_q, weight, out_grad_p, out_grad_q) = ctx.saved_tensors
pos_grad = None
quat_grad = None
batch_size = g_vec_p.shape[0] * g_vec_p.shape[1]
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
pos_grad, quat_grad = get_pose_distance_backward(
out_grad_p,
out_grad_q,
grad_out_distance.contiguous(),
grad_g_dist.contiguous(),
grad_r_err.contiguous(),
weight,
g_vec_p,
g_vec_q,
batch_size,
use_distance=True,
)
# pos_grad, quat_grad = backward_full_PoseError_jit(
# grad_out_distance,
# grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
# )
elif ctx.needs_input_grad[0]:
pos_grad = backward_PoseError_jit(grad_g_dist, grad_out_distance, p_w, g_vec_p)
# grad_vec = grad_g_dist + (grad_out_distance * weight[1])
# pos_grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec[..., 4:]
elif ctx.needs_input_grad[2]:
quat_grad = backward_PoseError_jit(grad_r_err, grad_out_distance, q_w, g_vec_q)
# grad_vec = grad_r_err + (grad_out_distance * weight[0])
# quat_grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec[..., :4]
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
return PoseCostMetric(
hold_partial_pose=self.hold_partial_pose,
release_partial_pose=self.release_partial_pose,
hold_vec_weight=None if self.hold_vec_weight is None else self.hold_vec_weight.clone(),
reach_partial_pose=self.reach_partial_pose,
reach_full_pose=self.reach_full_pose,
reach_vec_weight=(
None if self.reach_vec_weight is None else self.reach_vec_weight.clone()
),
offset_position=None if self.offset_position is None else self.offset_position.clone(),
offset_rotation=None if self.offset_rotation is None else self.offset_rotation.clone(),
offset_tstep_fraction=self.offset_tstep_fraction,
remove_offset_waypoint=self.remove_offset_waypoint,
)
@classmethod
def create_grasp_approach_metric(
cls,
offset_position: float = 0.5,
linear_axis: int = 2,
tstep_fraction: float = 0.6,
tensor_args: TensorDeviceType = TensorDeviceType(),
) -> PoseCostMetric:
"""Enables moving to a pregrasp and then locked orientation movement to final grasp.
class PoseLoss(Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode=PoseErrorType.BATCH_GOAL.value,
num_goals=1,
use_metric=False,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)
Since this is added as a cost, the trajectory will not reach the exact offset, instead it
will try to take a blended path to the final grasp without stopping at the offset.
(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
False,
use_metric,
)
ctx.save_for_backward(out_p_vec, out_r_vec)
return out_distance
Args:
offset_position: offset in meters.
linear_axis: specifies the x y or z axis.
tstep_fraction: specifies the timestep fraction to start activating this constraint.
tensor_args: cuda device.
@staticmethod
def backward(ctx, grad_out_distance): # , grad_g_dist, grad_r_err, grad_out_idx):
pos_grad = None
quat_grad = None
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p * grad_out_distance.unsqueeze(1)
quat_grad = g_vec_q * grad_out_distance.unsqueeze(1)
pos_grad = pos_grad.unsqueeze(-2)
quat_grad = quat_grad.unsqueeze(-2)
elif ctx.needs_input_grad[0]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p * grad_out_distance.unsqueeze(1)
pos_grad = pos_grad.unsqueeze(-2)
elif ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
quat_grad = g_vec_q * grad_out_distance.unsqueeze(1)
quat_grad = quat_grad.unsqueeze(-2)
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
Returns:
cost metric.
"""
hold_vec_weight = tensor_args.to_device([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
hold_vec_weight[3 + linear_axis] = 0.0
offset_position_vec = tensor_args.to_device([0.0, 0.0, 0.0])
offset_position_vec[linear_axis] = offset_position
return cls(
hold_partial_pose=True,
hold_vec_weight=hold_vec_weight,
offset_position=offset_position_vec,
offset_tstep_fraction=tstep_fraction,
)
class PoseError(Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode=PoseErrorType.BATCH_GOAL.value,
num_goals=1,
use_metric=False,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)
(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
False,
use_metric,
)
ctx.save_for_backward(out_p_vec, out_r_vec)
return out_distance
@staticmethod
def backward(ctx, grad_out_distance): # , grad_g_dist, grad_r_err, grad_out_idx):
pos_grad = None
quat_grad = None
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p
quat_grad = g_vec_q
elif ctx.needs_input_grad[0]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p
elif ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
quat_grad = g_vec_q
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
@classmethod
def reset_metric(cls) -> PoseCostMetric:
return PoseCostMetric(
remove_offset_waypoint=True,
reach_full_pose=True,
release_partial_pose=True,
)
@@ -449,13 +150,88 @@ class PoseCost(CostBase, PoseCostConfig):
CostBase.__init__(self)
self.rot_weight = self.vec_weight[0:3]
self.pos_weight = self.vec_weight[3:6]
self._vec_convergence = self.tensor_args.to_device(self.vec_convergence)
self._batch_size = 0
self._horizon = 0
def update_metric(self, metric: PoseCostMetric):
if metric.hold_partial_pose:
if metric.hold_vec_weight is None:
log_error("hold_vec_weight is required")
self.hold_partial_pose(metric.hold_vec_weight)
if metric.release_partial_pose:
self.release_partial_pose()
if metric.reach_partial_pose:
if metric.reach_vec_weight is None:
log_error("reach_vec_weight is required")
self.reach_partial_pose(metric.reach_vec_weight)
if metric.reach_full_pose:
self.reach_full_pose()
if metric.remove_offset_waypoint:
self.remove_offset_waypoint()
if metric.offset_position is not None or metric.offset_rotation is not None:
self.update_offset_waypoint(
offset_position=self.offset_position,
offset_rotation=self.offset_rotation,
offset_tstep_fraction=self.offset_tstep_fraction,
)
def hold_partial_pose(self, run_vec_weight: torch.Tensor):
self.run_vec_weight.copy_(run_vec_weight)
def release_partial_pose(self):
self.run_vec_weight[:] = 0.0
def reach_partial_pose(self, vec_weight: torch.Tensor):
self.vec_weight[:] = vec_weight
def reach_full_pose(self):
self.vec_weight[:] = 1.0
def update_offset_waypoint(
self,
offset_position: Optional[torch.Tensor] = None,
offset_rotation: Optional[torch.Tensor] = None,
offset_tstep_fraction: float = 0.75,
):
if offset_position is not None:
self.offset_waypoint[3:].copy_(offset_position)
if offset_rotation is not None:
self.offset_waypoint[:3].copy_(offset_rotation)
self.offset_tstep_fraction[:] = offset_tstep_fraction
if self._horizon <= 0:
print(self.weight)
log_error(
"Updating offset waypoint is only possible after initializing motion gen"
+ " run motion_gen.warmup() before adding offset_waypoint"
)
self.update_run_weight(run_tstep_fraction=offset_tstep_fraction)
def remove_offset_waypoint(self):
self.offset_tstep_fraction[:] = -1.0
self.update_run_weight()
def update_run_weight(
self, run_tstep_fraction: float = 0.0, run_weight: Optional[float] = None
):
if self._horizon == 1:
return
if run_weight is None:
run_weight = self.run_weight
active_steps = math.floor(self._horizon * run_tstep_fraction)
self._run_weight_vec[:, :active_steps] = 0
self._run_weight_vec[:, active_steps:-1] = run_weight
def update_batch_size(self, batch_size, horizon):
if batch_size != self._batch_size or horizon != self._horizon:
# print(self.weight)
# print(batch_size, horizon, self._batch_size, self._horizon)
# batch_size = b*h
self.out_distance = torch.zeros(
(batch_size, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
@@ -493,12 +269,16 @@ class PoseCost(CostBase, PoseCostConfig):
self._run_weight_vec = torch.ones(
(1, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
if self.terminal and self.run_weight is not None:
self._run_weight_vec[:, :-1] *= self.run_weight
if self.terminal and self.run_weight is not None and horizon > 1:
self._run_weight_vec[:, :-1] = self.run_weight
self._batch_size = batch_size
self._horizon = horizon
@property
def goalset_index_buffer(self):
return self.out_idx
def _forward_goal_distribution(self, ee_pos_batch, ee_rot_batch, ee_goal_pos, ee_goal_rot):
ee_goal_pos = ee_goal_pos.unsqueeze(1)
ee_goal_pos = ee_goal_pos.unsqueeze(1)
@@ -563,13 +343,13 @@ class PoseCost(CostBase, PoseCostConfig):
def _update_cost_type(self, ee_goal_pos, ee_pos_batch, num_goals):
d_g = len(ee_goal_pos.shape)
b_sze = ee_goal_pos.shape[0]
if d_g == 2 and b_sze == 1: # 1, 3
if d_g == 2 and b_sze == 1: # [1, 3]
self.cost_type = PoseErrorType.SINGLE_GOAL
elif d_g == 2 and b_sze == ee_pos_batch.shape[0]: # b, 3
elif d_g == 2 and b_sze > 1: # [b, 3]
self.cost_type = PoseErrorType.BATCH_GOAL
elif d_g == 3:
elif d_g == 3 and b_sze == 1: # [1, goalset, 3]
self.cost_type = PoseErrorType.GOALSET
elif len(ee_goal_pos.shape) == 4 and b_sze == ee_pos_bath.shape[0]:
elif d_g == 3 and b_sze > 1: # [b, goalset,3]
self.cost_type = PoseErrorType.BATCH_GOALSET
def forward_out_distance(
@@ -599,6 +379,8 @@ class PoseCost(CostBase, PoseCostConfig):
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
self.offset_waypoint,
self.offset_tstep_fraction,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
@@ -613,7 +395,9 @@ class PoseCost(CostBase, PoseCostConfig):
self.cost_type.value,
num_goals,
self.use_metric,
self.project_distance,
)
# print(self.out_idx.shape, self.out_idx[:,-1])
# print(goal.batch_pose_idx.shape)
cost = distance # .view(b, h)#.clone()
r_err = r_err # .view(b, h)
@@ -632,65 +416,46 @@ class PoseCost(CostBase, PoseCostConfig):
ee_goal_rot = goal_pose.quaternion
num_goals = goal_pose.n_goalset
self._update_cost_type(ee_goal_pos, ee_pos_batch, num_goals)
# print(self.cost_type)
b, h, _ = ee_pos_batch.shape
self.update_batch_size(b, h)
# return self.out_distance
# print(b,h, ee_goal_pos.shape)
if self.return_loss:
distance = PoseLoss.apply(
ee_pos_batch,
ee_goal_pos,
ee_rot_batch, # .view(-1, 4).contiguous(),
ee_goal_rot,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
else:
distance = PoseError.apply(
ee_pos_batch,
ee_goal_pos,
ee_rot_batch, # .view(-1, 4).contiguous(),
ee_goal_rot,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
distance = PoseError.apply(
ee_pos_batch,
ee_goal_pos,
ee_rot_batch, # .view(-1, 4).contiguous(),
ee_goal_rot,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
self.offset_waypoint,
self.offset_tstep_fraction,
goal.batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
self.project_distance,
self.return_loss,
)
cost = distance
# if link_name is None and cost.shape[0]==8:
# print(ee_pos_batch[...,-1].squeeze())
# print(cost.shape)
return cost
def forward_pose(
@@ -708,56 +473,34 @@ class PoseCost(CostBase, PoseCostConfig):
b = query_pose.position.shape[0]
h = query_pose.position.shape[1]
num_goals = 1
if self.return_loss:
distance = PoseLoss.apply(
query_pose.position.unsqueeze(1),
ee_goal_pos,
query_pose.quaternion.unsqueeze(1),
ee_goal_quat,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
else:
distance = PoseError.apply(
query_pose.position.unsqueeze(1),
ee_goal_pos,
query_pose.quaternion.unsqueeze(1),
ee_goal_quat,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
)
distance = PoseError.apply(
query_pose.position.unsqueeze(1),
ee_goal_pos,
query_pose.quaternion.unsqueeze(1),
ee_goal_quat,
self.vec_weight,
self.weight,
self._vec_convergence,
self._run_weight_vec,
self.run_vec_weight,
self.offset_waypoint,
self.offset_tstep_fraction,
batch_pose_idx,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
self.out_p_vec,
self.out_q_vec,
self.out_idx,
self.out_p_grad,
self.out_q_grad,
b,
h,
self.cost_type.value,
num_goals,
self.use_metric,
self.project_distance,
self.return_loss,
)
return distance

View File

@@ -68,9 +68,14 @@ class StopCost(CostBase, StopCostConfig):
self.max_vel = (sum_matrix @ delta_vel).unsqueeze(-1)
def forward(self, vels):
vel_abs = torch.abs(vels)
vel_abs = torch.nn.functional.relu(vel_abs - self.max_vel)
cost = self.weight * (torch.sum(vel_abs**2, dim=-1))
cost = velocity_cost(vels, self.weight, self.max_vel)
return cost
@torch.jit.script
def velocity_cost(vels, weight, max_vel):
vel_abs = torch.abs(vels)
vel_abs = torch.nn.functional.relu(vel_abs - max_vel[: vels.shape[1]])
cost = weight * (torch.sum(vel_abs**2, dim=-1))
return cost