improved joint space planning

Balakumar Sundaralingam
2024-05-30 14:42:22 -07:00
parent 3bfed9d773
commit 0c51dd2da8
28 changed files with 1135 additions and 213 deletions


@@ -18,7 +18,8 @@ import torch
import warp as wp
# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.logger import log_error
from curobo.util.torch_utils import get_cache_fn_decorator, get_torch_jit_decorator
from curobo.util.warp import init_warp
# Local Folder
@@ -37,6 +38,7 @@ class DistType(Enum):
class DistCostConfig(CostConfig):
    dist_type: DistType = DistType.L2
    use_null_space: bool = False
    use_l2_kernel: bool = True

    def __post_init__(self):
        return super().__post_init__()
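
The new use_l2_kernel flag (default True) routes DistCost through the templated Warp loop kernel added below; setting it to False keeps the original per-dof L2DistFunction path. A minimal configuration sketch, assuming weight, vec_weight, and tensor_args are inherited from the parent CostConfig as elsewhere in cuRobo (the tensor values are placeholders):

# Sketch only: dist_type, use_null_space, and use_l2_kernel are from this diff;
# the remaining fields are assumed to come from the parent CostConfig.
config = DistCostConfig(
    dist_type=DistType.L2,
    use_null_space=False,
    use_l2_kernel=True,  # set False to fall back to the per-dof L2DistFunction path
    weight=weight_tensor,  # placeholder tensors
    vec_weight=vec_weight_tensor,
    tensor_args=tensor_args,
)
dist_cost = DistCost(config)
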
@@ -142,6 +144,89 @@ def forward_l2_warp(
        out_grad_p[b_addrs] = g_p


@get_cache_fn_decorator()
def make_l2_kernel(dof_template: int):
    def forward_l2_loop_warp(
        pos: wp.array(dtype=wp.float32),
        target: wp.array(dtype=wp.float32),
        target_idx: wp.array(dtype=wp.int32),
        weight: wp.array(dtype=wp.float32),
        run_weight: wp.array(dtype=wp.float32),
        vec_weight: wp.array(dtype=wp.float32),
        out_cost: wp.array(dtype=wp.float32),
        out_grad_p: wp.array(dtype=wp.float32),
        write_grad: wp.uint8,  # this should be a bool
        batch_size: wp.int32,
        horizon: wp.int32,
        dof: wp.int32,
    ):
        tid = wp.tid()

        # initialize variables:
        b_id = wp.int32(0)
        h_id = wp.int32(0)
        b_addrs = wp.int32(0)
        target_id = wp.int32(0)
        w = wp.float32(0.0)
        r_w = wp.float32(0.0)
        c_total = wp.float32(0.0)

        # we launch batch * horizon threads; each thread loops over all dof
        b_id = tid / (horizon)
        h_id = tid - (b_id * horizon)
        if b_id >= batch_size or h_id >= horizon:
            return

        # read weights:
        w = weight[0]
        r_w = run_weight[h_id]
        w = r_w * w
        if w == 0.0:
            return

        # compute cost:
        b_addrs = b_id * horizon * dof + h_id * dof

        # read buffers:
        current_position = wp.vector(dtype=wp.float32, length=dof_template)
        target_position = wp.vector(dtype=wp.float32, length=dof_template)
        vec_weight_local = wp.vector(dtype=wp.float32, length=dof_template)

        target_id = target_idx[b_id]
        target_id = target_id * dof
        for i in range(dof_template):
            current_position[i] = pos[b_addrs + i]
            target_position[i] = target[target_id + i]
            vec_weight_local[i] = vec_weight[i]

        error = wp.cw_mul(vec_weight_local, (current_position - target_position))
        c_length = wp.length(error)

        if w > 100.0:
            p_w_alpha = 70.0
            c_total = w * wp.log2(wp.cosh(p_w_alpha * c_length))
            g_p = error * (
                w
                * p_w_alpha
                * wp.sinh(p_w_alpha * c_length)
                / (c_length * wp.cosh(p_w_alpha * c_length))
            )
        else:
            g_p = w * error
            if c_length > 0.0:
                g_p = g_p / c_length
            c_total = w * c_length

        out_cost[b_id * horizon + h_id] = c_total

        # compute gradient
        if write_grad == 1:
            for i in range(dof_template):
                out_grad_p[b_addrs + i] = g_p[i]

    return wp.Kernel(forward_l2_loop_warp)
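
make_l2_kernel closes over dof_template so the robot's dof becomes a compile-time constant, which is what allows the fixed-length wp.vector buffers and fixed-trip-count loops above; get_cache_fn_decorator then keeps Warp from rebuilding the kernel every time a DistCost is constructed with the same dof. A plausible stand-in for that decorator, assuming it is a simple argument-keyed memoizer (the real one lives in curobo.util.torch_utils and may differ):

# Assumption: get_cache_fn_decorator memoizes on the call arguments, so
# make_l2_kernel(7) compiles one wp.Kernel and later calls reuse it.
from functools import lru_cache


def get_cache_fn_decorator():
    return lru_cache(maxsize=None)
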
# create a bound cost tensor:
class L2DistFunction(torch.autograd.Function):
    @staticmethod
@@ -195,6 +280,58 @@ class L2DistFunction(torch.autograd.Function):
        return p_g, None, None, None, None, None, None, None, None


# create a bound cost tensor:
class L2DistLoopFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pos,
        target,
        target_idx,
        weight,
        run_weight,
        vec_weight,
        out_cost,
        out_cost_v,
        out_gp,
        l2_dof_kernel,
    ):
        wp_device = wp.device_from_torch(pos.device)
        b, h, dof = pos.shape
        wp.launch(
            kernel=l2_dof_kernel,
            dim=b * h,
            inputs=[
                wp.from_torch(pos.detach().view(-1).contiguous(), dtype=wp.float32),
                wp.from_torch(target.view(-1), dtype=wp.float32),
                wp.from_torch(target_idx.view(-1), dtype=wp.int32),
                wp.from_torch(weight, dtype=wp.float32),
                wp.from_torch(run_weight.view(-1), dtype=wp.float32),
                wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
                wp.from_torch(out_cost.view(-1), dtype=wp.float32),
                wp.from_torch(out_gp.view(-1), dtype=wp.float32),
                pos.requires_grad,
                b,
                h,
                dof,
            ],
            device=wp_device,
            stream=wp.stream_from_torch(pos.device),
        )
        ctx.save_for_backward(out_gp)
        return out_cost

    @staticmethod
    def backward(ctx, grad_out_cost):
        (p_grad,) = ctx.saved_tensors
        p_g = None
        if ctx.needs_input_grad[0]:
            p_g = p_grad
        return p_g, None, None, None, None, None, None, None, None, None
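
L2DistLoopFunction.forward follows Warp's standard Torch interop pattern: tensors are wrapped zero-copy with wp.from_torch, the launch device is resolved with wp.device_from_torch, and the kernel is enqueued on Torch's CUDA stream via wp.stream_from_torch so no extra synchronization is needed. A stripped-down sketch of the same pattern outside the autograd wrapper (launch_on_torch_tensor, my_kernel, and its single-input signature are hypothetical):

# Hypothetical helper mirroring the launch call above; `my_kernel` is assumed
# to be a wp.Kernel that takes one flat float32 array.
import torch
import warp as wp


def launch_on_torch_tensor(my_kernel, x: torch.Tensor):
    wp.launch(
        kernel=my_kernel,
        dim=x.numel(),
        inputs=[wp.from_torch(x.contiguous().view(-1), dtype=wp.float32)],
        device=wp.device_from_torch(x.device),
        stream=wp.stream_from_torch(x.device),  # stay on Torch's current CUDA stream
    )
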
class DistCost(CostBase, DistCostConfig):
    def __init__(self, config: Optional[DistCostConfig] = None):
        if config is not None:
@@ -202,6 +339,8 @@ class DistCost(CostBase, DistCostConfig):
        CostBase.__init__(self)
        self._init_post_config()
        init_warp()
        if self.use_l2_kernel:
            self._l2_dof_kernel = make_l2_kernel(self.dof)

    def _init_post_config(self):
        if self.vec_weight is not None:
@@ -210,13 +349,21 @@ class DistCost(CostBase, DistCostConfig):
            self.vec_weight = self.vec_weight * 0.0 + 1.0

    def update_batch_size(self, batch, horizon, dof):
        if dof != self.dof:
            log_error("dof cannot be changed after initializing DistCost")
        if self._batch_size != batch or self._horizon != horizon or self._dof != dof:
            self._out_cv_buffer = torch.zeros(
                (batch, horizon, dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
            )
            self._out_c_buffer = torch.zeros(
                (batch, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
            )
            self._out_c_buffer = None
            self._out_cv_buffer = None
            if self.use_l2_kernel:
                self._out_c_buffer = torch.zeros(
                    (batch, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
                )
            else:
                self._out_cv_buffer = torch.zeros(
                    (batch, horizon, dof),
                    device=self.tensor_args.device,
                    dtype=self.tensor_args.dtype,
                )
            self._out_g_buffer = torch.zeros(
                (batch, horizon, dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
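
With use_l2_kernel enabled, update_batch_size now allocates only the (batch, horizon) cost buffer, since each kernel thread reduces over dof internally; the per-dof (batch, horizon, dof) buffer remains for the fallback path, and the gradient buffer is shared by both. An illustrative shape comparison (sizes are made up):

# Illustrative buffer shapes only; batch, horizon, and dof values are made up.
import torch

batch, horizon, dof = 4, 32, 7
out_c = torch.zeros((batch, horizon))        # loop-kernel path: cost already summed over dof
out_cv = torch.zeros((batch, horizon, dof))  # fallback path: per-dof cost, reduced afterwards
out_g = torch.zeros((batch, horizon, dof))   # gradient buffer needed by both paths
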
@@ -293,24 +440,59 @@ class DistCost(CostBase, DistCostConfig):
        else:
            raise NotImplementedError("terminal flag needs to be set to true")
        if self.dist_type == DistType.L2:
            # print(goal_idx.shape, goal_vec.shape)
            cost = L2DistFunction.apply(
                current_vec,
                goal_vec,
                goal_idx,
                self.weight,
                self._run_weight_vec,
                self.vec_weight,
                self._out_c_buffer,
                self._out_cv_buffer,
                self._out_g_buffer,
            )
            if self.use_l2_kernel:
                cost = L2DistLoopFunction.apply(
                    current_vec,
                    goal_vec,
                    goal_idx,
                    self.weight,
                    self._run_weight_vec,
                    self.vec_weight,
                    self._out_c_buffer,
                    None,
                    self._out_g_buffer,
                    self._l2_dof_kernel,
                )
            else:
                cost = L2DistFunction.apply(
                    current_vec,
                    goal_vec,
                    goal_idx,
                    self.weight,
                    self._run_weight_vec,
                    self.vec_weight,
                    None,
                    self._out_cv_buffer,
                    self._out_g_buffer,
                )
        else:
            raise NotImplementedError()
        if RETURN_GOAL_DIST:
            dist_scale = torch.nan_to_num(
                1.0 / torch.sqrt((self.weight * self._run_weight_vec)), 0.0
            )
            return cost, cost * dist_scale
            if self.use_l2_kernel:
                distance = weight_cost_to_l2_jit(cost, self.weight, self._run_weight_vec)
            else:
                distance = squared_cost_to_l2_jit(cost, self.weight, self._run_weight_vec)
            return cost, distance
        return cost
@get_torch_jit_decorator()
def squared_cost_to_l2_jit(cost, weight, run_weight_vec):
    weight_inv = weight * run_weight_vec
    weight_inv = 1.0 / weight_inv
    weight_inv = torch.nan_to_num(weight_inv, 0.0)
    distance = torch.sqrt(cost * weight_inv)
    return distance


@get_torch_jit_decorator()
def weight_cost_to_l2_jit(cost, weight, run_weight_vec):
    weight_inv = weight * run_weight_vec
    weight_inv = 1.0 / weight_inv
    weight_inv = torch.nan_to_num(weight_inv, 0.0)
    distance = cost * weight_inv
    return distance
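
For reference, the two helpers undo the weighting folded into the returned cost so that RETURN_GOAL_DIST can report an unweighted joint-space distance: the loop kernel writes weight * run_weight * ||error|| per step (in its nominal w <= 100 branch), so a plain division recovers the distance, while the older path appears to store a squared weighted cost and therefore also needs a square root. A quick algebra check with made-up numbers (the squared-cost form is an assumption inferred from squared_cost_to_l2_jit):

# Made-up weights and errors; only checks the algebra used by the two helpers.
import torch

weight = torch.tensor(10.0)
run_weight = torch.tensor([1.0, 2.0])
l2_error = torch.tensor([0.5, 2.0])            # per-step joint-space L2 error

cost_l2 = weight * run_weight * l2_error        # form written by forward_l2_loop_warp
cost_sq = weight * run_weight * l2_error ** 2   # assumed form of the L2DistFunction cost

assert torch.allclose(cost_l2 * (1.0 / (weight * run_weight)), l2_error)
assert torch.allclose(torch.sqrt(cost_sq * (1.0 / (weight * run_weight))), l2_error)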