#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
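"""Distance cost for cuRobo.

Provides Torch JIT helpers and NVIDIA Warp kernels that evaluate L1, L2, and squared-L2
distance costs (with analytic gradients) between current values and targets."""
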
# Standard Library
from dataclasses import dataclass
from enum import Enum
from typing import Optional

# Third Party
import torch
import warp as wp

# CuRobo
from curobo.util.logger import log_error
from curobo.util.torch_utils import get_cache_fn_decorator, get_torch_jit_decorator
from curobo.util.warp import init_warp, warp_support_kernel_key

# Local Folder
from .cost_base import CostBase, CostConfig

wp.set_module_options({"fast_math": False})


class DistType(Enum):
    L1 = 0
    L2 = 1
    SQUARED_L2 = 2


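# Configuration for DistCost:
#   dist_type      -- metric used by forward() / forward_target(): L1, L2, or squared L2.
#   use_null_space -- when False, any user-provided vec_weight is reset to all ones so every
#                     dimension is weighted equally (see DistCost._init_post_config).
#   use_l2_kernel  -- when True, forward_target_idx() launches the dof-templated Warp kernel
#                     built by make_l2_kernel(); otherwise the per-element kernel
#                     forward_l2_warp is used and the squared terms are summed over dof.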
@dataclass
class DistCostConfig(CostConfig):
    dist_type: DistType = DistType.L2
    use_null_space: bool = False
    use_l2_kernel: bool = True

    def __post_init__(self):
        return super().__post_init__()


@get_torch_jit_decorator()
def L2_DistCost_jit(vec_weight, disp_vec):
    return torch.norm(vec_weight * disp_vec, p=2, dim=-1, keepdim=False)


@get_torch_jit_decorator()
def fwd_SQL2_DistCost_jit(vec_weight, disp_vec):
    return torch.sum(torch.square(vec_weight * disp_vec), dim=-1, keepdim=False)


@get_torch_jit_decorator()
def fwd_L1_DistCost_jit(vec_weight, disp_vec):
    return torch.sum(torch.abs(vec_weight * disp_vec), dim=-1, keepdim=False)


@get_torch_jit_decorator()
def L2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
    return torch.norm(weight * vec_weight * (g_vec - c_vec), p=2, dim=-1, keepdim=False)


@get_torch_jit_decorator()
def fwd_SQL2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
    return torch.sum(torch.square(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)


@get_torch_jit_decorator()
def fwd_L1_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
    return torch.sum(torch.abs(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)


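# Warp kernel: one thread per (batch, horizon, dof) element. Each thread computes the weighted
# squared error w * (pos - target)^2 and, when write_grad is set, its gradient
# 2 * w * (pos - target), where w folds together the scalar weight, the per-timestep run_weight,
# and the per-dimension vec_weight. Threads whose effective weight is zero exit early.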
@wp.kernel
def forward_l2_warp(
    pos: wp.array(dtype=wp.float32),
    target: wp.array(dtype=wp.float32),
    target_idx: wp.array(dtype=wp.int32),
    weight: wp.array(dtype=wp.float32),
    run_weight: wp.array(dtype=wp.float32),
    vec_weight: wp.array(dtype=wp.float32),
    out_cost: wp.array(dtype=wp.float32),
    out_grad_p: wp.array(dtype=wp.float32),
    write_grad: wp.uint8,  # this should be a bool
    batch_size: wp.int32,
    horizon: wp.int32,
    dof: wp.int32,
):
    tid = wp.tid()
    # initialize variables:
    b_id = wp.int32(0)
    h_id = wp.int32(0)
    d_id = wp.int32(0)
    b_addrs = wp.int32(0)
    target_id = wp.int32(0)
    w = wp.float32(0.0)
    c_p = wp.float32(0.0)
    target_p = wp.float32(0.0)
    g_p = wp.float32(0.0)
    r_w = wp.float32(0.0)
    c_total = wp.float32(0.0)

    # we launch batch * horizon * dof kernels
    b_id = tid / (horizon * dof)
    h_id = (tid - (b_id * horizon * dof)) / dof
    d_id = tid - (b_id * horizon * dof + h_id * dof)

    if b_id >= batch_size or h_id >= horizon or d_id >= dof:
        return

    # read weights:
    w = weight[0]
    r_w = run_weight[h_id]
    w = r_w * w
    r_w = vec_weight[d_id]
    w = r_w * w
    if w == 0.0:
        return
    # compute cost:
    b_addrs = b_id * horizon * dof + h_id * dof + d_id

    # read buffers:
    c_p = pos[b_addrs]
    target_id = target_idx[b_id]
    target_id = target_id * dof + d_id
    target_p = target[target_id]
    error = c_p - target_p

    # if r_w >= 1.0 and w > 100.0:
    #     c_total = w * wp.log2(wp.cosh(10.0 * error))
    #     g_p = w * 10.0 * wp.sinh(10.0 * error) / (wp.cosh(10.0 * error))
    # else:
    c_total = w * error * error
    g_p = 2.0 * w * error

    out_cost[b_addrs] = c_total

    # compute gradient
    if write_grad == 1:
        out_grad_p[b_addrs] = g_p


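# Kernel factory: builds a Warp kernel specialized to a fixed dof (dof_template), one thread per
# (batch, horizon) pair. It evaluates the weighted L2 norm of the joint error and its gradient.
# For large weights (w > 100.0) it switches to a smooth log-cosh cost,
# w * log2(cosh(p_w_alpha * ||error||)), whose gradient saturates for large errors instead of
# growing linearly. The factory is wrapped in get_cache_fn_decorator() so repeated calls with the
# same dof reuse the kernel, and the explicit key keeps the kernel hash unique per dof on older
# Warp versions.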
@get_cache_fn_decorator()
def make_l2_kernel(dof_template: int):
    def forward_l2_loop_warp(
        pos: wp.array(dtype=wp.float32),
        target: wp.array(dtype=wp.float32),
        target_idx: wp.array(dtype=wp.int32),
        weight: wp.array(dtype=wp.float32),
        run_weight: wp.array(dtype=wp.float32),
        vec_weight: wp.array(dtype=wp.float32),
        out_cost: wp.array(dtype=wp.float32),
        out_grad_p: wp.array(dtype=wp.float32),
        write_grad: wp.uint8,  # this should be a bool
        batch_size: wp.int32,
        horizon: wp.int32,
        dof: wp.int32,
    ):
        tid = wp.tid()
        # initialize variables:
        b_id = wp.int32(0)
        h_id = wp.int32(0)
        b_addrs = wp.int32(0)
        target_id = wp.int32(0)
        w = wp.float32(0.0)
        r_w = wp.float32(0.0)
        c_total = wp.float32(0.0)

        # we launch batch * horizon kernels
        b_id = tid / (horizon)
        h_id = tid - (b_id * horizon)

        if b_id >= batch_size or h_id >= horizon:
            return

        # read weights:
        w = weight[0]
        r_w = run_weight[h_id]
        w = r_w * w

        if w == 0.0:
            return
        # compute cost:
        b_addrs = b_id * horizon * dof + h_id * dof

        # read buffers:
        current_position = wp.vector(dtype=wp.float32, length=dof_template)
        target_position = wp.vector(dtype=wp.float32, length=dof_template)
        vec_weight_local = wp.vector(dtype=wp.float32, length=dof_template)
        target_id = target_idx[b_id]
        target_id = target_id * dof

        for i in range(dof_template):
            current_position[i] = pos[b_addrs + i]
            target_position[i] = target[target_id + i]
            vec_weight_local[i] = vec_weight[i]

        error = wp.cw_mul(vec_weight_local, (current_position - target_position))

        c_length = wp.length(error)
        if w > 100.0:
            p_w_alpha = 70.0
            c_total = w * wp.log2(wp.cosh(p_w_alpha * c_length))
            g_p = error * (
                w
                * p_w_alpha
                * wp.sinh(p_w_alpha * c_length)
                / (c_length * wp.cosh(p_w_alpha * c_length))
            )
        else:
            g_p = w * error
            if c_length > 0.0:
                g_p = g_p / c_length
            c_total = w * c_length

        out_cost[b_id * horizon + h_id] = c_total

        # compute gradient
        if write_grad == 1:
            for i in range(dof_template):
                out_grad_p[b_addrs + i] = g_p[i]

    module = wp.get_module(forward_l2_loop_warp.__module__)
    key = "forward_l2_loop" + str(dof_template)
    new_kernel = wp.Kernel(forward_l2_loop_warp, key=key, module=module)
    return new_kernel


# Autograd wrapper that launches the per-element squared-L2 kernel (forward_l2_warp) and sums
# the per-dof terms into a (batch, horizon) cost tensor:
class L2DistFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pos,
        target,
        target_idx,
        weight,
        run_weight,
        vec_weight,
        out_cost,
        out_cost_v,
        out_gp,
    ):
        wp_device = wp.device_from_torch(pos.device)
        b, h, dof = pos.shape
        requires_grad = pos.requires_grad
        wp.launch(
            kernel=forward_l2_warp,
            dim=b * h * dof,
            inputs=[
                wp.from_torch(pos.detach().reshape(-1), dtype=wp.float32),
                wp.from_torch(target.view(-1), dtype=wp.float32),
                wp.from_torch(target_idx.view(-1), dtype=wp.int32),
                wp.from_torch(weight, dtype=wp.float32),
                wp.from_torch(run_weight.view(-1), dtype=wp.float32),
                wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
                wp.from_torch(out_cost_v.view(-1), dtype=wp.float32),
                wp.from_torch(out_gp.view(-1), dtype=wp.float32),
                requires_grad,
                b,
                h,
                dof,
            ],
            device=wp_device,
            stream=wp.stream_from_torch(pos.device),
        )

        cost = torch.sum(out_cost_v, dim=-1)
        ctx.save_for_backward(out_gp)
        return cost

    @staticmethod
    def backward(ctx, grad_out_cost):
        (p_grad,) = ctx.saved_tensors

        p_g = None
        if ctx.needs_input_grad[0]:
            p_g = p_grad
        return p_g, None, None, None, None, None, None, None, None


# Autograd wrapper that launches the dof-templated L2 loop kernel built by make_l2_kernel():
class L2DistLoopFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pos,
        target,
        target_idx,
        weight,
        run_weight,
        vec_weight,
        out_cost,
        out_cost_v,
        out_gp,
        l2_dof_kernel,
    ):
        wp_device = wp.device_from_torch(pos.device)
        b, h, dof = pos.shape
        wp.launch(
            kernel=l2_dof_kernel,
            dim=b * h,
            inputs=[
                wp.from_torch(pos.detach().view(-1).contiguous(), dtype=wp.float32),
                wp.from_torch(target.view(-1), dtype=wp.float32),
                wp.from_torch(target_idx.view(-1), dtype=wp.int32),
                wp.from_torch(weight, dtype=wp.float32),
                wp.from_torch(run_weight.view(-1), dtype=wp.float32),
                wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
                wp.from_torch(out_cost.view(-1), dtype=wp.float32),
                wp.from_torch(out_gp.view(-1), dtype=wp.float32),
                pos.requires_grad,
                b,
                h,
                dof,
            ],
            device=wp_device,
            stream=wp.stream_from_torch(pos.device),
        )

        ctx.save_for_backward(out_gp)
        return out_cost

    @staticmethod
    def backward(ctx, grad_out_cost):
        (p_grad,) = ctx.saved_tensors

        p_g = None
        if ctx.needs_input_grad[0]:
            p_g = p_grad
        return p_g, None, None, None, None, None, None, None, None, None


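# Distance cost term. Wraps the Torch JIT paths (forward, forward_target) and the Warp kernel
# paths (forward_target_idx) behind a single CostBase-derived class: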
class DistCost(CostBase, DistCostConfig):
    def __init__(self, config: Optional[DistCostConfig] = None):
        if config is not None:
            DistCostConfig.__init__(self, **vars(config))
        CostBase.__init__(self)
        self._init_post_config()
        init_warp()
        if self.use_l2_kernel:
            if not warp_support_kernel_key():
                # define a compile-time constant so that warp hash is different for different dof
                # this is required in older warp versions < 1.2.1 as warp hash didn't consider the
                # name of kernels. Newer warp versions have fixed this issue.
                WARP_CUROBO_DISTCOST_DOF_GLOBAL_CONSTANT = wp.constant(self.dof)

            self._l2_dof_kernel = make_l2_kernel(self.dof)

    def _init_post_config(self):
        if self.vec_weight is not None:
            self.vec_weight = self.tensor_args.to_device(self.vec_weight)
            if not self.use_null_space:
                self.vec_weight = self.vec_weight * 0.0 + 1.0

    def update_batch_size(self, batch, horizon, dof):
        if dof != self.dof:
            log_error("dof cannot be changed after initializing DistCost")
        if self._batch_size != batch or self._horizon != horizon or self._dof != dof:
            self._out_c_buffer = None
            self._out_cv_buffer = None
            if self.use_l2_kernel:
                self._out_c_buffer = torch.zeros(
                    (batch, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
                )
            else:
                self._out_cv_buffer = torch.zeros(
                    (batch, horizon, dof),
                    device=self.tensor_args.device,
                    dtype=self.tensor_args.dtype,
                )

            self._out_g_buffer = torch.zeros(
                (batch, horizon, dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
            )

            self._batch_size = batch
            self._horizon = horizon
            self._dof = dof
        if self.vec_weight is None:
            self.vec_weight = torch.ones(
                (1, 1, self._dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
            )

    def forward(self, disp_vec, RETURN_GOAL_DIST=False):
        if self.dist_type == DistType.L2:
            # dist = torch.norm(disp_vec, p=2, dim=-1, keepdim=False)
            dist = L2_DistCost_jit(self.vec_weight, disp_vec)
        elif self.dist_type == DistType.SQUARED_L2:
            # cost = weight * (0.5 * torch.square(torch.norm(disp_vec, p=2, dim=-1)))
            # dist = torch.sum(torch.square(disp_vec), dim=-1, keepdim=False)
            dist = fwd_SQL2_DistCost_jit(self.vec_weight, disp_vec)
        elif self.dist_type == DistType.L1:
            # dist = torch.sum(torch.abs(disp_vec), dim=-1, keepdim=False)
            dist = fwd_L1_DistCost_jit(self.vec_weight, disp_vec)

        cost = self.weight * dist
        if self.terminal and self.run_weight is not None:
            if self._run_weight_vec is None or self._run_weight_vec.shape[1] != cost.shape[1]:
                self._run_weight_vec = torch.ones(
                    (1, cost.shape[1]), device=self.tensor_args.device, dtype=self.tensor_args.dtype
                )
                self._run_weight_vec[:, :-1] *= self.run_weight
        if RETURN_GOAL_DIST:
            return cost, dist
        return cost

    def forward_target(self, goal_vec, current_vec, RETURN_GOAL_DIST=False):
        if self.dist_type == DistType.L2:
            # dist = torch.norm(disp_vec, p=2, dim=-1, keepdim=False)
            cost = L2_DistCost_target_jit(self.vec_weight, goal_vec, current_vec, self.weight)
        elif self.dist_type == DistType.SQUARED_L2:
            # cost = weight * (0.5 * torch.square(torch.norm(disp_vec, p=2, dim=-1)))
            # dist = torch.sum(torch.square(disp_vec), dim=-1, keepdim=False)
            cost = fwd_SQL2_DistCost_target_jit(self.vec_weight, goal_vec, current_vec, self.weight)
        elif self.dist_type == DistType.L1:
            # dist = torch.sum(torch.abs(disp_vec), dim=-1, keepdim=False)
            cost = fwd_L1_DistCost_target_jit(self.vec_weight, goal_vec, current_vec, self.weight)
        dist = cost
        if self.terminal and self.run_weight is not None:
            if self._run_weight_vec is None or self._run_weight_vec.shape[1] != cost.shape[1]:
                self._run_weight_vec = torch.ones(
                    (1, cost.shape[1]), device=self.tensor_args.device, dtype=self.tensor_args.dtype
                )
                self._run_weight_vec[:, :-1] *= self.run_weight
            cost = self._run_weight_vec * dist
        if RETURN_GOAL_DIST:
            dist_scale = torch.nan_to_num(
                1.0 / torch.sqrt(self.weight * self._run_weight_vec), 0.0
            )

            return cost, dist * dist_scale
        return cost

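    # Warp-accelerated path: current_vec is (batch, horizon, dof), goal_vec holds stacked
    # targets, and goal_idx selects which target row each batch element tracks. Cost and
    # gradient buffers are (re)allocated in update_batch_size() and reused across calls.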
    def forward_target_idx(self, goal_vec, current_vec, goal_idx, RETURN_GOAL_DIST=False):
        b, h, dof = current_vec.shape
        self.update_batch_size(b, h, dof)

        if self.terminal and self.run_weight is not None:
            if self._run_weight_vec is None or self._run_weight_vec.shape[1] != h:
                self._run_weight_vec = torch.ones(
                    (1, h), device=self.tensor_args.device, dtype=self.tensor_args.dtype
                )
                self._run_weight_vec[:, :-1] *= self.run_weight
        else:
            raise NotImplementedError("terminal flag needs to be set to true")
        if self.dist_type == DistType.L2:
            if self.use_l2_kernel:
                cost = L2DistLoopFunction.apply(
                    current_vec,
                    goal_vec,
                    goal_idx,
                    self.weight,
                    self._run_weight_vec,
                    self.vec_weight,
                    self._out_c_buffer,
                    None,
                    self._out_g_buffer,
                    self._l2_dof_kernel,
                )
            else:
                cost = L2DistFunction.apply(
                    current_vec,
                    goal_vec,
                    goal_idx,
                    self.weight,
                    self._run_weight_vec,
                    self.vec_weight,
                    None,
                    self._out_cv_buffer,
                    self._out_g_buffer,
                )
        else:
            raise NotImplementedError()
        if RETURN_GOAL_DIST:
            if self.use_l2_kernel:
                distance = weight_cost_to_l2_jit(cost, self.weight, self._run_weight_vec)
            else:
                distance = squared_cost_to_l2_jit(cost, self.weight, self._run_weight_vec)
            return cost, distance

        return cost


@get_torch_jit_decorator()
def squared_cost_to_l2_jit(cost, weight, run_weight_vec):
    weight_inv = weight * run_weight_vec
    weight_inv = 1.0 / weight_inv
    weight_inv = torch.nan_to_num(weight_inv, 0.0)
    distance = torch.sqrt(cost * weight_inv)
    return distance


@get_torch_jit_decorator()
def weight_cost_to_l2_jit(cost, weight, run_weight_vec):
    weight_inv = weight * run_weight_vec
    weight_inv = 1.0 / weight_inv
    weight_inv = torch.nan_to_num(weight_inv, 0.0)
    distance = cost * weight_inv
    return distance
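

# Illustrative sketch (not part of the library API): the Torch JIT helpers above can be
# exercised directly on CPU tensors. The shapes below are assumptions for the example only --
# a batch of 4 trajectories, horizon 30, 7 dof, with a uniform per-dimension weight vector.
if __name__ == "__main__":
    disp_vec = torch.rand(4, 30, 7)  # displacement from target, e.g. current - goal
    vec_weight = torch.ones(1, 1, 7)  # weight each dof equally

    l2 = L2_DistCost_jit(vec_weight, disp_vec)  # (4, 30) L2 norm over dof
    sql2 = fwd_SQL2_DistCost_jit(vec_weight, disp_vec)  # (4, 30) sum of squares over dof
    l1 = fwd_L1_DistCost_jit(vec_weight, disp_vec)  # (4, 30) sum of absolute values over dof
    print(l2.shape, sql2.shape, l1.shape)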