improved joint space planning
@@ -261,7 +261,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
            log_warn(
                "null space cost is deprecated, use null_space_weight in bound cost instead"
            )

        self.cost_cfg.bound_cfg.dof = self.n_dofs
        self.bound_cost = BoundCost(self.cost_cfg.bound_cfg)

        if self.cost_cfg.manipulability_cfg is not None:
@@ -315,10 +315,12 @@ class ArmBase(RolloutBase, ArmBaseConfig):
            self.cost_cfg.bound_cfg.state_finite_difference_mode = (
                self.dynamics_model.state_finite_difference_mode
            )

            self.cost_cfg.bound_cfg.dof = self.n_dofs
            self.constraint_cfg.bound_cfg.dof = self.n_dofs
            self.bound_constraint = BoundCost(self.constraint_cfg.bound_cfg)

        if self.convergence_cfg.null_space_cfg is not None:
            self.convergence_cfg.null_space_cfg.dof = self.n_dofs
            self.null_convergence = DistCost(self.convergence_cfg.null_space_cfg)

        # set start state:
@@ -578,6 +580,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
        ----------
        action_seq: torch.Tensor [num_particles, horizon, d_act]
        """

        # print(act_seq.shape, self._goal_buffer.batch_current_state_idx)
        if self.start_state is None:
            raise ValueError("start_state is not set in rollout")
@@ -585,6 +588,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
        state = self.dynamics_model.forward(
            self.start_state, act_seq, self._goal_buffer.batch_current_state_idx
        )

        with profiler.record_function("cost/all"):
            cost_seq = self.cost_fn(state, act_seq)

@@ -174,6 +174,7 @@ class ArmReacher(ArmBase, ArmReacherConfig):
        self._n_goalset = 1

        if self.cost_cfg.cspace_cfg is not None:
            self.cost_cfg.cspace_cfg.dof = self.d_action
            # self.cost_cfg.cspace_cfg.update_vec_weight(self.dynamics_model.cspace_distance_weight)
            self.dist_cost = DistCost(self.cost_cfg.cspace_cfg)
        if self.cost_cfg.pose_cfg is not None:
@@ -226,6 +227,7 @@ class ArmReacher(ArmBase, ArmReacherConfig):
                if i != self.kinematics.ee_link:
                    self._link_pose_convergence[i] = PoseCost(self.convergence_cfg.link_pose_cfg)
        if self.convergence_cfg.cspace_cfg is not None:
            self.convergence_cfg.cspace_cfg.dof = self.d_action
            self.cspace_convergence = DistCost(self.convergence_cfg.cspace_cfg)

        # check if g_dist is required in any of the cost terms:

@@ -22,7 +22,7 @@ from curobo.cuda_robot_model.types import JointLimits
from curobo.types.robot import JointState
from curobo.types.tensor import T_DOF
from curobo.util.logger import log_error
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.torch_utils import get_cache_fn_decorator, get_torch_jit_decorator
from curobo.util.warp import init_warp

# Local Folder
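The newly imported get_cache_fn_decorator is not defined in this diff; from its use on the kernel factories below it appears to memoize a factory on its arguments, so each DOF count compiles its Warp kernel only once. A minimal sketch of what such a decorator could look like, assuming lru_cache-style behavior (this is a guess, not cuRobo's actual implementation):

import functools

def get_cache_fn_decorator():
    # Assumed behavior: cache the factory's return value per argument tuple, so
    # make_l2_kernel(dof) or make_bound_pos_kernel(dof) builds and compiles the
    # dof-templated Warp kernel once and reuses it on later calls.
    def decorator(fn):
        return functools.lru_cache(maxsize=None)(fn)
    return decorator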
@@ -49,6 +49,7 @@ class BoundCostConfig(CostConfig):
    activation_distance: Union[torch.Tensor, float] = 0.0
    state_finite_difference_mode: str = "BACKWARD"
    null_space_weight: Optional[List[float]] = None
    use_l2_kernel: bool = False

    def set_bounds(self, bounds: JointLimits, teleport_mode: bool = False):
        self.joint_limits = bounds.clone()
@@ -104,12 +105,21 @@ class BoundCost(CostBase, BoundCostConfig):
            (0, 0, 0), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self._out_gv_buffer = self._out_ga_buffer = self._out_gj_buffer = empty_buffer
        if self.use_l2_kernel:
            if self.cost_type == BoundCostType.POSITION:
                self._l2_cost = make_bound_pos_kernel(self.dof)
            if self.cost_type == BoundCostType.BOUNDS_SMOOTH:
                self._l2_cost = make_bound_pos_smooth_kernel(self.dof)

    def update_batch_size(self, batch, horizon, dof):
        if self._batch_size != batch or self._horizon != horizon or self._dof != dof:
            self._out_c_buffer = torch.zeros(
                (batch, horizon, dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
            )
            self._out_c_sum_buffer = torch.zeros(
                (batch, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
            )

            self._out_gp_buffer = torch.zeros(
                (batch, horizon, dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
            )
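The two cost buffers above differ in shape because the new fused L2 kernels reduce over the joints inside Warp, while the existing per-joint kernels leave the DOF dimension for torch to sum. A small shape check illustrating the difference (buffer names follow this diff; the dimensions are only example values):

import torch

batch, horizon, dof = 2, 30, 7
out_c = torch.zeros(batch, horizon, dof)   # per-joint buffer, reduced later with torch.sum(..., dim=-1)
out_c_sum = torch.zeros(batch, horizon)    # fused-L2 buffer, already reduced inside the Warp kernel
assert out_c.sum(dim=-1).shape == out_c_sum.shape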
@@ -184,36 +194,61 @@ class BoundCost(CostBase, BoundCostConfig):
            retract_idx = self._retract_cfg_idx

        if self.cost_type == BoundCostType.BOUNDS_SMOOTH:
            # print(self.joint_limits.jerk.shape, self.joint_limits.position.shape)
            cost = WarpBoundSmoothFunction.apply(
                state_batch.position,
                state_batch.velocity,
                state_batch.acceleration,
                state_batch.jerk,
                retract_config,
                retract_idx,
                self.joint_limits.position,
                self.joint_limits.velocity,
                self.joint_limits.acceleration,
                self.joint_limits.jerk,
                self.weight,
                self.activation_distance,
                self.smooth_weight,
                self.cspace_distance_weight,
                self.null_space_weight,
                self.vec_weight,
                self._run_weight_vel,
                self._run_weight_acc,
                self._run_weight_jerk,
                self._out_c_buffer,
                self._out_gp_buffer,
                self._out_gv_buffer,
                self._out_ga_buffer,
                self._out_gj_buffer,
            )
            # print(self.cspace_distance_weight.shape)
            # print(cost)
            # print(self._run_weight_acc)
            if self.use_l2_kernel:
                cost = WarpBoundSmoothL2Function.apply(
                    state_batch.position,
                    state_batch.velocity,
                    state_batch.acceleration,
                    state_batch.jerk,
                    retract_config,
                    retract_idx,
                    self.joint_limits.position,
                    self.joint_limits.velocity,
                    self.joint_limits.acceleration,
                    self.joint_limits.jerk,
                    self.weight,
                    self.activation_distance,
                    self.smooth_weight,
                    self.cspace_distance_weight,
                    self.null_space_weight,
                    self.vec_weight,
                    self._run_weight_vel,
                    self._run_weight_acc,
                    self._run_weight_jerk,
                    self._out_c_sum_buffer,
                    self._out_gp_buffer,
                    self._out_gv_buffer,
                    self._out_ga_buffer,
                    self._out_gj_buffer,
                    self._l2_cost,
                )
            else:
                cost = WarpBoundSmoothFunction.apply(
                    state_batch.position,
                    state_batch.velocity,
                    state_batch.acceleration,
                    state_batch.jerk,
                    retract_config,
                    retract_idx,
                    self.joint_limits.position,
                    self.joint_limits.velocity,
                    self.joint_limits.acceleration,
                    self.joint_limits.jerk,
                    self.weight,
                    self.activation_distance,
                    self.smooth_weight,
                    self.cspace_distance_weight,
                    self.null_space_weight,
                    self.vec_weight,
                    self._run_weight_vel,
                    self._run_weight_acc,
                    self._run_weight_jerk,
                    self._out_c_buffer,
                    self._out_gp_buffer,
                    self._out_gv_buffer,
                    self._out_ga_buffer,
                    self._out_gj_buffer,
                )
        elif self.cost_type == BoundCostType.BOUNDS:
            cost = WarpBoundFunction.apply(
                state_batch.position,
@@ -237,8 +272,8 @@ class BoundCost(CostBase, BoundCostConfig):
                self._out_gj_buffer,
            )
        elif self.cost_type == BoundCostType.POSITION:
            if self.return_loss:
                cost = WarpBoundPosLoss.apply(
            if self.use_l2_kernel:
                cost = WarpBoundPosL2Function.apply(
                    state_batch.position,
                    retract_config,
                    retract_idx,
@@ -247,8 +282,10 @@ class BoundCost(CostBase, BoundCostConfig):
                    self.activation_distance,
                    self.null_space_weight,
                    self.vec_weight,
                    self._out_c_buffer,
                    self._out_c_sum_buffer,
                    self._out_gp_buffer,
                    self._l2_cost,
                    self.return_loss,
                )
            else:
                cost = WarpBoundPosFunction.apply(
@@ -262,6 +299,7 @@ class BoundCost(CostBase, BoundCostConfig):
                    self.vec_weight,
                    self._out_c_buffer,
                    self._out_gp_buffer,
                    self.return_loss,
                )

        else:
@@ -458,6 +496,132 @@ class WarpBoundSmoothFunction(torch.autograd.Function):
        )


# create a bound cost tensor:
class WarpBoundSmoothL2Function(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pos,
        vel,
        acc,
        jerk,
        retract_config,
        retract_idx,
        p_b,
        v_b,
        a_b,
        j_b,
        weight,
        activation_distance,
        smooth_weight,
        cspace_weight,
        null_space_weight,
        vec_weight,
        run_weight_vel,
        run_weight_acc,
        run_weight_jerk,
        out_cost,
        out_gp,
        out_gv,
        out_ga,
        out_gj,
        warp_function,
    ):
        # scale the weights for smoothness by this dt:
        wp_device = wp.device_from_torch(vel.device)
        # assert smooth_weight.shape[0] == 7
        b, h, dof = vel.shape
        requires_grad = pos.requires_grad

        wp.launch(
            kernel=warp_function,
            dim=b * h,
            inputs=[
                wp.from_torch(pos.detach().view(-1), dtype=wp.float32),
                wp.from_torch(vel.detach().view(-1), dtype=wp.float32),
                wp.from_torch(acc.detach().view(-1), dtype=wp.float32),
                wp.from_torch(jerk.detach().view(-1), dtype=wp.float32),
                wp.from_torch(retract_config.detach().view(-1), dtype=wp.float32),
                wp.from_torch(retract_idx.detach().view(-1), dtype=wp.int32),
                wp.from_torch(p_b.view(-1), dtype=wp.float32),
                wp.from_torch(v_b.view(-1), dtype=wp.float32),
                wp.from_torch(a_b.view(-1), dtype=wp.float32),
                wp.from_torch(j_b.view(-1), dtype=wp.float32),
                wp.from_torch(weight, dtype=wp.float32),
                wp.from_torch(activation_distance, dtype=wp.float32),
                wp.from_torch(smooth_weight, dtype=wp.float32),
                wp.from_torch(cspace_weight, dtype=wp.float32),
                wp.from_torch(null_space_weight.view(-1), dtype=wp.float32),
                wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
                wp.from_torch(run_weight_vel.view(-1), dtype=wp.float32),
                wp.from_torch(run_weight_acc.view(-1), dtype=wp.float32),
                wp.from_torch(run_weight_jerk.view(-1), dtype=wp.float32),
                wp.from_torch(out_cost.view(-1), dtype=wp.float32),
                wp.from_torch(out_gp.view(-1), dtype=wp.float32),
                wp.from_torch(out_gv.view(-1), dtype=wp.float32),
                wp.from_torch(out_ga.view(-1), dtype=wp.float32),
                wp.from_torch(out_gj.view(-1), dtype=wp.float32),
                requires_grad,
                b,
                h,
                dof,
            ],
            device=wp_device,
            stream=wp.stream_from_torch(vel.device),
        )
        ctx.save_for_backward(out_gp, out_gv, out_ga, out_gj)

        return out_cost

    @staticmethod
    def backward(ctx, grad_out_cost):
        (
            p_grad,
            v_grad,
            a_grad,
            j_grad,
        ) = ctx.saved_tensors
        v_g = None
        a_g = None
        p_g = None
        j_g = None
        if ctx.needs_input_grad[0]:
            p_g = p_grad  # * grad_out_cost#.unsqueeze(-1)
        if ctx.needs_input_grad[1]:
            v_g = v_grad  # * grad_out_cost#.unsqueeze(-1)
        if ctx.needs_input_grad[2]:
            a_g = a_grad  # * grad_out_cost#.unsqueeze(-1)
        if ctx.needs_input_grad[3]:
            j_g = j_grad  # * grad_out_cost#.unsqueeze(-1)
        return (
            p_g,
            v_g,
            a_g,
            j_g,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )

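The autograd wrapper above follows the same pattern as the other Warp-backed cost functions in this file: the forward pass launches a kernel that writes both the cost and the per-input gradient buffers, and backward simply returns the pre-computed gradients (the multiplication by grad_out_cost is left commented out). A stripped-down sketch of that pattern using plain torch ops in place of the wp.launch call (names here are illustrative, not part of cuRobo):

import torch

class PrecomputedGradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, pos, out_cost, out_gp):
        # Stand-in for the Warp kernel: fill the cost and gradient buffers in one pass.
        out_cost.copy_((pos * pos).sum(dim=-1))
        out_gp.copy_(2.0 * pos)
        ctx.save_for_backward(out_gp)
        return out_cost

    @staticmethod
    def backward(ctx, grad_out_cost):
        (p_grad,) = ctx.saved_tensors
        p_g = None
        if ctx.needs_input_grad[0]:
            # Chain rule with the stored gradient buffer.
            p_g = p_grad * grad_out_cost.unsqueeze(-1)
        return p_g, None, None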
class WarpBoundFunction(torch.autograd.Function):
    @staticmethod
    def forward(
@@ -581,6 +745,7 @@ class WarpBoundPosFunction(torch.autograd.Function):
        vec_weight,
        out_cost,
        out_gp,
        return_loss=False,
    ):
        wp_device = wp.device_from_torch(pos.device)
        b, h, dof = pos.shape
@@ -607,10 +772,9 @@ class WarpBoundPosFunction(torch.autograd.Function):
            device=wp_device,
            stream=wp.stream_from_torch(pos.device),
        )
        ctx.return_loss = return_loss
        ctx.save_for_backward(out_gp)
        # cost = torch.linalg.norm(out_cost, dim=-1)
        cost = torch.sum(out_cost, dim=-1)
        # cost = out_cost
        return cost

    @staticmethod
@@ -619,20 +783,71 @@ class WarpBoundPosFunction(torch.autograd.Function):

        p_g = None
        if ctx.needs_input_grad[0]:
            p_g = p_grad  # * grad_out_cost.unsqueeze(-1)
        return p_g, None, None, None, None, None, None, None, None, None
            p_g = p_grad
            if ctx.return_loss:
                p_g = p_grad * grad_out_cost.unsqueeze(-1)

        return p_g, None, None, None, None, None, None, None, None, None, None


# create a bound cost tensor:
class WarpBoundPosLoss(WarpBoundPosFunction):
class WarpBoundPosL2Function(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pos,
        retract_config,
        retract_idx,
        p_l,
        weight,
        activation_distance,
        null_space_weight,
        vec_weight,
        out_cost,
        out_gp,
        warp_function,
        return_loss=False,
    ):
        wp_device = wp.device_from_torch(pos.device)
        b, h, dof = pos.shape
        requires_grad = pos.requires_grad
        wp.launch(
            kernel=warp_function,
            dim=b * h,
            inputs=[
                wp.from_torch(pos.detach().view(-1), dtype=wp.float32),
                wp.from_torch(retract_config.detach().view(-1), dtype=wp.float32),
                wp.from_torch(retract_idx.detach().view(-1), dtype=wp.int32),
                wp.from_torch(p_l.view(-1), dtype=wp.float32),
                wp.from_torch(weight, dtype=wp.float32),
                wp.from_torch(activation_distance, dtype=wp.float32),
                wp.from_torch(null_space_weight.view(-1), dtype=wp.float32),
                wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
                wp.from_torch(out_cost.view(-1), dtype=wp.float32),
                wp.from_torch(out_gp.view(-1), dtype=wp.float32),
                requires_grad,
                b,
                h,
                dof,
            ],
            device=wp_device,
            stream=wp.stream_from_torch(pos.device),
        )
        ctx.return_loss = return_loss
        ctx.save_for_backward(out_gp)
        # cost = torch.sum(out_cost, dim=-1)
        return out_cost

    @staticmethod
    def backward(ctx, grad_out_cost):
        (p_grad,) = ctx.saved_tensors

        p_g = None
        if ctx.needs_input_grad[0]:
            p_g = p_grad * grad_out_cost.unsqueeze(-1)
        return p_g, None, None, None, None, None, None, None, None, None
            p_g = p_grad  # * grad_out_cost.unsqueeze(-1)
            if ctx.return_loss:
                p_g = p_grad * grad_out_cost.unsqueeze(-1)
        return p_g, None, None, None, None, None, None, None, None, None, None, None


@wp.kernel
@@ -1122,3 +1337,314 @@ def forward_bound_smooth_warp(
        out_grad_v[b_addrs] = g_v
        out_grad_a[b_addrs] = g_a
        out_grad_j[b_addrs] = g_j


@get_cache_fn_decorator()
def make_bound_pos_smooth_kernel(dof_template: int):
    def forward_bound_smooth_loop_warp(
        pos: wp.array(dtype=wp.float32),
        vel: wp.array(dtype=wp.float32),
        acc: wp.array(dtype=wp.float32),
        jerk: wp.array(dtype=wp.float32),
        retract_config: wp.array(dtype=wp.float32),
        retract_idx: wp.array(dtype=wp.int32),
        p_b: wp.array(dtype=wp.float32),
        v_b: wp.array(dtype=wp.float32),
        a_b: wp.array(dtype=wp.float32),
        j_b: wp.array(dtype=wp.float32),
        weight: wp.array(dtype=wp.float32),
        activation_distance: wp.array(dtype=wp.float32),
        smooth_weight: wp.array(dtype=wp.float32),
        cspace_weight: wp.array(dtype=wp.float32),
        null_weight: wp.array(dtype=wp.float32),
        vec_weight: wp.array(dtype=wp.float32),
        run_weight_vel: wp.array(dtype=wp.float32),
        run_weight_acc: wp.array(dtype=wp.float32),
        run_weight_jerk: wp.array(dtype=wp.float32),
        out_cost: wp.array(dtype=wp.float32),
        out_grad_p: wp.array(dtype=wp.float32),
        out_grad_v: wp.array(dtype=wp.float32),
        out_grad_a: wp.array(dtype=wp.float32),
        out_grad_j: wp.array(dtype=wp.float32),
        write_grad: wp.uint8,  # this should be a bool
        batch_size: wp.int32,
        horizon: wp.int32,
        dof: wp.int32,
    ):
        tid = wp.tid()
        # initialize variables:
        b_id = wp.int32(0)
        h_id = wp.int32(0)
        b_addrs = int(0)

        w = wp.float32(0.0)
        b_wv = float(0.0)
        b_wa = float(0.0)
        b_wj = float(0.0)

        r_wa = wp.float32(0.0)
        r_wj = wp.float32(0.0)

        w_a = wp.float32(0.0)
        w_j = wp.float32(0.0)

        s_a = wp.float32(0.0)
        s_j = wp.float32(0.0)

        c_total = wp.float32(0.0)

        # we launch batch * horizon * dof kernels
        b_id = tid / (horizon)
        h_id = tid - (b_id * horizon)
        if b_id >= batch_size or h_id >= horizon:
            return

        n_w = wp.float32(0.0)
        n_w = null_weight[0]

        vec_weight_local = wp.vector(dtype=wp.float32, length=dof_template)
        cspace_weight_local = wp.vector(dtype=wp.float32, length=dof_template)
        p_l = wp.vector(dtype=wp.float32, length=dof_template)
        p_u = wp.vector(dtype=wp.float32, length=dof_template)
        v_l = wp.vector(dtype=wp.float32, length=dof_template)
        v_u = wp.vector(dtype=wp.float32, length=dof_template)
        a_l = wp.vector(dtype=wp.float32, length=dof_template)
        a_u = wp.vector(dtype=wp.float32, length=dof_template)
        j_l = wp.vector(dtype=wp.float32, length=dof_template)
        j_u = wp.vector(dtype=wp.float32, length=dof_template)
        g_p = wp.vector(dtype=wp.float32, length=dof_template)
        g_v = wp.vector(dtype=wp.float32, length=dof_template)
        g_a = wp.vector(dtype=wp.float32, length=dof_template)
        g_j = wp.vector(dtype=wp.float32, length=dof_template)
        target_id = wp.int32(0.0)
        target_id = retract_idx[b_id]
        current_p = wp.vector(dtype=wp.float32, length=dof_template)
        current_v = wp.vector(dtype=wp.float32, length=dof_template)
        current_a = wp.vector(dtype=wp.float32, length=dof_template)
        current_j = wp.vector(dtype=wp.float32, length=dof_template)

        target_p = wp.vector(dtype=wp.float32, length=dof_template)

        b_addrs = b_id * horizon * dof + h_id * dof

        for i in range(dof_template):
            vec_weight_local[i] = vec_weight[i]
            target_p[i] = retract_config[target_id * dof + i]
            current_p[i] = pos[b_addrs + i]
            current_v[i] = vel[b_addrs + i]
            current_a[i] = acc[b_addrs + i]
            current_j[i] = jerk[b_addrs + i]
            p_l[i] = p_b[i]
            p_u[i] = p_b[dof + i]
            v_l[i] = v_b[i]
            v_u[i] = v_b[dof + i]
            a_l[i] = a_b[i]
            a_u[i] = a_b[dof + i]
            j_l[i] = j_b[i]
            j_u[i] = j_b[dof + i]
            cspace_weight_local[i] = cspace_weight[i]

        # read weights:
        w = weight[0]
        b_wv = weight[1]
        b_wa = weight[2]
        b_wj = weight[3]
        r_wa = run_weight_acc[h_id]
        r_wj = run_weight_jerk[h_id]

        w_a = smooth_weight[1]
        w_j = smooth_weight[2]

        # compute cost:

        # read buffers:

        # if w_j > 0.0:
        eta_p = activation_distance[0]
        eta_v = activation_distance[1]
        eta_a = activation_distance[2]
        eta_j = activation_distance[3]

        p_range = p_u - p_l
        p_l = p_l + (p_range) * eta_p
        p_u = p_u - (p_range) * eta_p
        v_l = v_l + (v_u - v_l) * eta_v
        v_u = v_u - (v_u - v_l) * eta_v
        a_l = a_l + (a_u - a_l) * eta_a
        a_u = a_u - (a_u - a_l) * eta_a
        j_l = j_l + (j_u - j_l) * eta_j
        j_u = j_u - (j_u - j_l) * eta_j

        # position:
        if n_w > 0.0:
            error = wp.cw_mul(vec_weight_local, current_p - target_p)
            error_length = wp.length(error)
            c_total = n_w * error_length
            if error_length > 0.0:
                g_p = n_w * error / error_length

        # bound cost:
        # bound cost:
        bound_delta_p = wp.vector(dtype=wp.float32, length=dof_template)
        bound_delta_v = wp.vector(dtype=wp.float32, length=dof_template)
        bound_delta_a = wp.vector(dtype=wp.float32, length=dof_template)
        bound_delta_j = wp.vector(dtype=wp.float32, length=dof_template)
        for i in range(dof_template):
            if current_p[i] < p_l[i]:
                bound_delta_p[i] = current_p[i] - p_l[i]
            elif current_p[i] > p_u[i]:
                bound_delta_p[i] = current_p[i] - p_u[i]
            if current_v[i] < v_l[i]:
                bound_delta_v[i] = current_v[i] - v_l[i]
            elif current_v[i] > v_u[i]:
                bound_delta_v[i] = current_v[i] - v_u[i]
            if current_a[i] < a_l[i]:
                bound_delta_a[i] = current_a[i] - a_l[i]
            elif current_a[i] > a_u[i]:
                bound_delta_a[i] = current_a[i] - a_u[i]
            if current_j[i] < j_l[i]:
                bound_delta_j[i] = current_j[i] - j_l[i]
            elif current_j[i] > j_u[i]:
                bound_delta_j[i] = current_j[i] - j_u[i]
        delta_p = wp.length(bound_delta_p)
        if delta_p > 0.0:
            g_p += w * bound_delta_p / delta_p
            c_total += w * delta_p
        delta_v = wp.length(bound_delta_v)
        if delta_v > 0.0:
            g_v = b_wv * bound_delta_v / delta_v
            c_total += b_wv * delta_v
        delta_a = wp.length(bound_delta_a)
        if delta_a > 0.0:
            g_a = b_wa * bound_delta_a / delta_a
            c_total += b_wa * delta_a
        delta_j = wp.length(bound_delta_j)
        if delta_j > 0.0:
            g_j = b_wj * bound_delta_j / delta_j
            c_total += b_wj * delta_j

        delta_acc = wp.cw_mul(cspace_weight_local, current_a)
        # delta_acc = wp.cw_div(delta_acc, a_u - a_l)
        acc_length = wp.length_sq(delta_acc)
        s_a = w_a * r_wa * acc_length
        if acc_length > 0.0:
            g_a += 2.0 * w_a * r_wa * delta_acc  # / acc_length

        delta_jerk = wp.cw_mul(cspace_weight_local, current_j)
        # delta_jerk = wp.cw_div(delta_jerk, j_u - j_l)
        jerk_length = wp.length_sq(delta_jerk)
        s_j = w_j * r_wj * jerk_length
        if jerk_length > 0.0:
            g_j += 2.0 * w_j * r_wj * delta_jerk  # / jerk_length

        c_total += s_a + s_j

        out_cost[b_id * horizon + h_id] = c_total

        # compute gradient
        if write_grad == 1:
            b_addrs = b_id * horizon * dof + h_id * dof
            for i in range(dof_template):
                out_grad_p[b_addrs + i] = g_p[i]
                out_grad_v[b_addrs + i] = g_v[i]
                out_grad_a[b_addrs + i] = g_a[i]
                out_grad_j[b_addrs + i] = g_j[i]

    return wp.Kernel(forward_bound_smooth_loop_warp)


@get_cache_fn_decorator()
def make_bound_pos_kernel(dof_template: int):

    def forward_bound_pos_loop_warp(
        pos: wp.array(dtype=wp.float32),
        retract_config: wp.array(dtype=wp.float32),
        retract_idx: wp.array(dtype=wp.int32),
        p_b: wp.array(dtype=wp.float32),
        weight: wp.array(dtype=wp.float32),
        activation_distance: wp.array(dtype=wp.float32),
        null_weight: wp.array(dtype=wp.float32),
        vec_weight: wp.array(dtype=wp.float32),
        out_cost: wp.array(dtype=wp.float32),
        out_grad_p: wp.array(dtype=wp.float32),
        write_grad: wp.uint8,  # this should be a bool
        batch_size: wp.int32,
        horizon: wp.int32,
        dof: wp.int32,
    ):
        tid = wp.tid()
        # initialize variables:
        b_id = wp.int32(0)
        h_id = wp.int32(0)
        b_addrs = int(0)

        w = wp.float32(0.0)
        c_total = wp.float32(0.0)

        # we launch batch * horizon * dof kernels
        b_id = tid / (horizon)
        h_id = tid - (b_id * horizon)
        if b_id >= batch_size or h_id >= horizon:
            return

        # read weights:
        eta_p = activation_distance[0]
        w = weight[0]

        n_w = wp.float32(0.0)
        n_w = null_weight[0]
        target_id = wp.int32(0.0)
        vec_weight_local = wp.vector(dtype=wp.float32, length=dof_template)
        p_l = wp.vector(dtype=wp.float32, length=dof_template)
        p_u = wp.vector(dtype=wp.float32, length=dof_template)
        bound_delta = wp.vector(dtype=wp.float32, length=dof_template)
        target_p = wp.vector(dtype=wp.float32, length=dof_template)
        current_p = wp.vector(dtype=wp.float32, length=dof_template)
        g_p = wp.vector(dtype=wp.float32, length=dof_template)

        target_id = retract_idx[b_id]
        b_addrs = b_id * horizon * dof + h_id * dof

        for i in range(dof_template):
            vec_weight_local[i] = vec_weight[i]
            target_p[i] = retract_config[target_id * dof + i]
            current_p[i] = pos[b_addrs + i]
            p_l[i] = p_b[i]
            p_u[i] = p_b[dof + i]

        p_range = p_u - p_l
        eta_percent = eta_p * (p_range)
        p_l += eta_percent
        p_u -= eta_percent

        # compute retract cost:

        if n_w > 0.0:
            error = wp.cw_mul(vec_weight_local, current_p - target_p)
            error_length = wp.length(error)
            c_total = n_w * error_length
            if error_length > 0.0:
                g_p = n_w * error / error_length

        # bound cost:

        for i in range(dof_template):
            if current_p[i] < p_l[i]:
                bound_delta[i] = current_p[i] - p_l[i]
            elif current_p[i] > p_u[i]:
                bound_delta[i] = current_p[i] - p_u[i]

        delta = wp.length(bound_delta)
        if delta > 0.0:
            g_p += w * bound_delta / delta
            c_total += w * delta

        out_cost[b_id * horizon + h_id] = c_total

        # compute gradient
        if write_grad == 1:
            b_addrs = b_id * horizon * dof + h_id * dof
            for i in range(dof_template):
                out_grad_p[b_addrs + i] = g_p[i]

    return wp.Kernel(forward_bound_pos_loop_warp)

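Both factories above close over dof_template because wp.vector needs its length fixed when the kernel is built, so each robot DOF count gets its own compiled kernel; the @get_cache_fn_decorator() wrapper presumably keeps that compilation to once per DOF. A hypothetical usage sketch under that caching assumption (the DOF values are only examples):

# assuming lru_cache-like memoization on the factory:
kernel_7dof = make_bound_pos_kernel(7)        # first call builds and compiles the 7-DOF kernel
kernel_7dof_again = make_bound_pos_kernel(7)  # later calls reuse the cached wp.Kernel
kernel_6dof = make_bound_pos_kernel(6)        # a different DOF count gets its own specialization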
@@ -18,7 +18,8 @@ import torch
import warp as wp

# CuRobo
from curobo.util.torch_utils import get_torch_jit_decorator
from curobo.util.logger import log_error
from curobo.util.torch_utils import get_cache_fn_decorator, get_torch_jit_decorator
from curobo.util.warp import init_warp

# Local Folder
@@ -37,6 +38,7 @@ class DistType(Enum):
class DistCostConfig(CostConfig):
    dist_type: DistType = DistType.L2
    use_null_space: bool = False
    use_l2_kernel: bool = True

    def __post_init__(self):
        return super().__post_init__()
@@ -142,6 +144,89 @@ def forward_l2_warp(
        out_grad_p[b_addrs] = g_p


@get_cache_fn_decorator()
def make_l2_kernel(dof_template: int):
    def forward_l2_loop_warp(
        pos: wp.array(dtype=wp.float32),
        target: wp.array(dtype=wp.float32),
        target_idx: wp.array(dtype=wp.int32),
        weight: wp.array(dtype=wp.float32),
        run_weight: wp.array(dtype=wp.float32),
        vec_weight: wp.array(dtype=wp.float32),
        out_cost: wp.array(dtype=wp.float32),
        out_grad_p: wp.array(dtype=wp.float32),
        write_grad: wp.uint8,  # this should be a bool
        batch_size: wp.int32,
        horizon: wp.int32,
        dof: wp.int32,
    ):
        tid = wp.tid()
        # initialize variables:
        b_id = wp.int32(0)
        h_id = wp.int32(0)
        b_addrs = wp.int32(0)
        target_id = wp.int32(0)
        w = wp.float32(0.0)
        r_w = wp.float32(0.0)
        c_total = wp.float32(0.0)

        # we launch batch * horizon * dof kernels
        b_id = tid / (horizon)
        h_id = tid - (b_id * horizon)

        if b_id >= batch_size or h_id >= horizon:
            return

        # read weights:
        w = weight[0]
        r_w = run_weight[h_id]
        w = r_w * w

        if w == 0.0:
            return
        # compute cost:
        b_addrs = b_id * horizon * dof + h_id * dof

        # read buffers:
        current_position = wp.vector(dtype=wp.float32, length=dof_template)
        target_position = wp.vector(dtype=wp.float32, length=dof_template)
        vec_weight_local = wp.vector(dtype=wp.float32, length=dof_template)
        target_id = target_idx[b_id]
        target_id = target_id * dof

        for i in range(dof_template):
            current_position[i] = pos[b_addrs + i]
            target_position[i] = target[target_id + i]
            vec_weight_local[i] = vec_weight[i]

        error = wp.cw_mul(vec_weight_local, (current_position - target_position))

        c_length = wp.length(error)
        if w > 100.0:
            p_w_alpha = 70.0
            c_total = w * wp.log2(wp.cosh(p_w_alpha * c_length))
            g_p = error * (
                w
                * p_w_alpha
                * wp.sinh(p_w_alpha * c_length)
                / (c_length * wp.cosh(p_w_alpha * c_length))
            )
        else:
            g_p = w * error
            if c_length > 0.0:
                g_p = g_p / c_length
            c_total = w * c_length

        out_cost[b_id * horizon + h_id] = c_total

        # compute gradient
        if write_grad == 1:
            for i in range(dof_template):
                out_grad_p[b_addrs + i] = g_p[i]

    return wp.Kernel(forward_l2_loop_warp)


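For large weights the kernel above swaps the plain weighted distance w * ||e|| for a log-cosh softened version, w * log2(cosh(70 * ||e||)), with a gradient proportional to tanh(70 * ||e||) * e / ||e||, which stays smooth and bounded near zero error instead of having a kink there. A small numeric sketch of the two branches in plain Python (values are only illustrative; cosh overflows for large arguments, which is not a concern for the small joint-space errors used here):

import math

def l2_branch(c_length, w=1.0):
    # plain weighted distance (the w <= 100 branch)
    return w * c_length

def log_cosh_branch(c_length, w=200.0, alpha=70.0):
    # softened distance used when w > 100
    return w * math.log2(math.cosh(alpha * c_length))

print(l2_branch(0.01), log_cosh_branch(0.01))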
# create a bound cost tensor:
class L2DistFunction(torch.autograd.Function):
    @staticmethod
@@ -195,6 +280,58 @@ class L2DistFunction(torch.autograd.Function):
        return p_g, None, None, None, None, None, None, None, None


# create a bound cost tensor:
class L2DistLoopFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pos,
        target,
        target_idx,
        weight,
        run_weight,
        vec_weight,
        out_cost,
        out_cost_v,
        out_gp,
        l2_dof_kernel,
    ):
        wp_device = wp.device_from_torch(pos.device)
        b, h, dof = pos.shape
        wp.launch(
            kernel=l2_dof_kernel,
            dim=b * h,
            inputs=[
                wp.from_torch(pos.detach().view(-1).contiguous(), dtype=wp.float32),
                wp.from_torch(target.view(-1), dtype=wp.float32),
                wp.from_torch(target_idx.view(-1), dtype=wp.int32),
                wp.from_torch(weight, dtype=wp.float32),
                wp.from_torch(run_weight.view(-1), dtype=wp.float32),
                wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
                wp.from_torch(out_cost.view(-1), dtype=wp.float32),
                wp.from_torch(out_gp.view(-1), dtype=wp.float32),
                pos.requires_grad,
                b,
                h,
                dof,
            ],
            device=wp_device,
            stream=wp.stream_from_torch(pos.device),
        )

        ctx.save_for_backward(out_gp)
        return out_cost

    @staticmethod
    def backward(ctx, grad_out_cost):
        (p_grad,) = ctx.saved_tensors

        p_g = None
        if ctx.needs_input_grad[0]:
            p_g = p_grad
        return p_g, None, None, None, None, None, None, None, None, None


class DistCost(CostBase, DistCostConfig):
    def __init__(self, config: Optional[DistCostConfig] = None):
        if config is not None:
@@ -202,6 +339,8 @@ class DistCost(CostBase, DistCostConfig):
        CostBase.__init__(self)
        self._init_post_config()
        init_warp()
        if self.use_l2_kernel:
            self._l2_dof_kernel = make_l2_kernel(self.dof)

    def _init_post_config(self):
        if self.vec_weight is not None:
@@ -210,13 +349,21 @@ class DistCost(CostBase, DistCostConfig):
            self.vec_weight = self.vec_weight * 0.0 + 1.0

    def update_batch_size(self, batch, horizon, dof):
        if dof != self.dof:
            log_error("dof cannot be changed after initializing DistCost")
        if self._batch_size != batch or self._horizon != horizon or self._dof != dof:
            self._out_cv_buffer = torch.zeros(
                (batch, horizon, dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
            )
            self._out_c_buffer = torch.zeros(
                (batch, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
            )
            self._out_c_buffer = None
            self._out_cv_buffer = None
            if self.use_l2_kernel:
                self._out_c_buffer = torch.zeros(
                    (batch, horizon), device=self.tensor_args.device, dtype=self.tensor_args.dtype
                )
            else:
                self._out_cv_buffer = torch.zeros(
                    (batch, horizon, dof),
                    device=self.tensor_args.device,
                    dtype=self.tensor_args.dtype,
                )

            self._out_g_buffer = torch.zeros(
                (batch, horizon, dof), device=self.tensor_args.device, dtype=self.tensor_args.dtype
@@ -293,24 +440,59 @@ class DistCost(CostBase, DistCostConfig):
        else:
            raise NotImplementedError("terminal flag needs to be set to true")
        if self.dist_type == DistType.L2:
            # print(goal_idx.shape, goal_vec.shape)
            cost = L2DistFunction.apply(
                current_vec,
                goal_vec,
                goal_idx,
                self.weight,
                self._run_weight_vec,
                self.vec_weight,
                self._out_c_buffer,
                self._out_cv_buffer,
                self._out_g_buffer,
            )
            if self.use_l2_kernel:

                cost = L2DistLoopFunction.apply(
                    current_vec,
                    goal_vec,
                    goal_idx,
                    self.weight,
                    self._run_weight_vec,
                    self.vec_weight,
                    self._out_c_buffer,
                    None,
                    self._out_g_buffer,
                    self._l2_dof_kernel,
                )
            else:
                cost = L2DistFunction.apply(
                    current_vec,
                    goal_vec,
                    goal_idx,
                    self.weight,
                    self._run_weight_vec,
                    self.vec_weight,
                    None,
                    self._out_cv_buffer,
                    self._out_g_buffer,
                )

        else:
            raise NotImplementedError()
        if RETURN_GOAL_DIST:
            dist_scale = torch.nan_to_num(
                1.0 / torch.sqrt((self.weight * self._run_weight_vec)), 0.0
            )
            return cost, cost * dist_scale

            if self.use_l2_kernel:
                distance = weight_cost_to_l2_jit(cost, self.weight, self._run_weight_vec)
            else:
                distance = squared_cost_to_l2_jit(cost, self.weight, self._run_weight_vec)
            return cost, distance

        return cost


@get_torch_jit_decorator()
def squared_cost_to_l2_jit(cost, weight, run_weight_vec):
    weight_inv = weight * run_weight_vec
    weight_inv = 1.0 / weight_inv
    weight_inv = torch.nan_to_num(weight_inv, 0.0)
    distance = torch.sqrt(cost * weight_inv)
    return distance


@get_torch_jit_decorator()
def weight_cost_to_l2_jit(cost, weight, run_weight_vec):
    weight_inv = weight * run_weight_vec
    weight_inv = 1.0 / weight_inv
    weight_inv = torch.nan_to_num(weight_inv, 0.0)
    distance = cost * weight_inv
    return distance

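The two helpers above undo the weighting to report a goal distance: weight_cost_to_l2_jit assumes the cost is already the weighted L2 norm (the fused-kernel path), so it only divides by the weights, while squared_cost_to_l2_jit assumes a weighted squared norm and therefore also takes a square root. A quick torch check of that relationship (example values only):

import torch

w = torch.tensor([4.0])
e = torch.tensor([0.3, -0.4])
l2 = torch.linalg.norm(e)  # 0.5
assert torch.allclose((w * l2) / w, l2)                 # L2-kernel path: divide by weight
assert torch.allclose(torch.sqrt((w * l2**2) / w), l2)  # squared path: divide, then sqrt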
@@ -522,27 +522,20 @@ class TensorStepPositionCliqueKernel(TensorStepBase):
            self._u_grad,
        )
        if self._filter_velocity:
            out_state_seq.aux_data["raw_velocity"] = out_state_seq.velocity
            out_state_seq.velocity = self.filter_signal(out_state_seq.velocity)

        if self._filter_acceleration:
            out_state_seq.aux_data["raw_acceleration"] = out_state_seq.acceleration
            out_state_seq.acceleration = self.filter_signal(out_state_seq.acceleration)

        if self._filter_jerk:
            out_state_seq.aux_data["raw_jerk"] = out_state_seq.jerk
            out_state_seq.jerk = self.filter_signal(out_state_seq.jerk)
        return out_state_seq

    def filter_signal(self, signal: torch.Tensor):
        return filter_signal_jit(signal, self._sma_kernel)
        b, h, dof = signal.shape
        new_signal = (
            self._sma(
                signal.transpose(-1, -2).reshape(b * dof, 1, h), self._sma_kernel, padding="same"
            )
            .view(b, dof, h)
            .transpose(-1, -2)
            .contiguous()
        )
        return new_signal


@get_torch_jit_decorator(force_jit=True)