Improved precision, quality and js planner.

This commit is contained in:
Balakumar Sundaralingam
2024-04-11 13:19:01 -07:00
parent 774dcfd609
commit d6e600c88c
51 changed files with 2128 additions and 1054 deletions

View File

@@ -355,6 +355,8 @@ class WarpBoundSmoothFunction(torch.autograd.Function):
wp_device = wp.device_from_torch(vel.device)
# assert smooth_weight.shape[0] == 7
b, h, dof = vel.shape
requires_grad = pos.requires_grad
wp.launch(
kernel=forward_bound_smooth_warp,
dim=b * h * dof,
@@ -383,7 +385,7 @@ class WarpBoundSmoothFunction(torch.autograd.Function):
wp.from_torch(out_gv.view(-1), dtype=wp.float32),
wp.from_torch(out_ga.view(-1), dtype=wp.float32),
wp.from_torch(out_gj.view(-1), dtype=wp.float32),
pos.requires_grad,
requires_grad,
b,
h,
dof,
@@ -471,6 +473,7 @@ class WarpBoundFunction(torch.autograd.Function):
):
wp_device = wp.device_from_torch(vel.device)
b, h, dof = vel.shape
requires_grad = pos.requires_grad
wp.launch(
kernel=forward_bound_warp,
dim=b * h * dof,
@@ -494,7 +497,7 @@ class WarpBoundFunction(torch.autograd.Function):
wp.from_torch(out_gv.view(-1), dtype=wp.float32),
wp.from_torch(out_ga.view(-1), dtype=wp.float32),
wp.from_torch(out_gj.view(-1), dtype=wp.float32),
pos.requires_grad,
requires_grad,
b,
h,
dof,
@@ -505,6 +508,7 @@ class WarpBoundFunction(torch.autograd.Function):
ctx.save_for_backward(out_gp, out_gv, out_ga, out_gj)
# out_c = out_cost
# out_c = torch.linalg.norm(out_cost, dim=-1)
out_c = torch.sum(out_cost, dim=-1)
return out_c
@@ -569,11 +573,11 @@ class WarpBoundPosFunction(torch.autograd.Function):
):
wp_device = wp.device_from_torch(pos.device)
b, h, dof = pos.shape
requires_grad = pos.requires_grad
wp.launch(
kernel=forward_bound_pos_warp,
dim=b * h * dof,
inputs=[
# wp.from_torch(pos.detach().view(-1).contiguous(), dtype=wp.float32),
wp.from_torch(pos.detach().view(-1), dtype=wp.float32),
wp.from_torch(retract_config.detach().view(-1), dtype=wp.float32),
wp.from_torch(retract_idx.detach().view(-1), dtype=wp.int32),
@@ -584,7 +588,7 @@ class WarpBoundPosFunction(torch.autograd.Function):
wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
wp.from_torch(out_cost.view(-1), dtype=wp.float32),
wp.from_torch(out_gp.view(-1), dtype=wp.float32),
pos.requires_grad,
requires_grad,
b,
h,
dof,
@@ -685,23 +689,33 @@ def forward_bound_pos_warp(
c_total = n_w * error * error
g_p = 2.0 * n_w * error
# bound cost:
# if c_p < p_l:
# delta = p_l - c_p
# if (delta) > eta_p or eta_p == 0.0:
# c_total += w * (delta - 0.5 * eta_p)
# g_p += -w
# else:
# c_total += w * (0.5 / eta_p) * delta * delta
# g_p += -w * (1.0 / eta_p) * delta
# elif c_p > p_u:
# delta = c_p - p_u
# if (delta) > eta_p or eta_p == 0.0:
# c_total += w * (delta - 0.5 * eta_p)
# g_p += w
# else:
# c_total += w * (0.5 / eta_p) * delta * delta
# g_p += w * (1.0 / eta_p) * delta
# bound cost:
if c_p < p_l:
delta = p_l - c_p
if (delta) > eta_p or eta_p == 0.0:
c_total += w * (delta - 0.5 * eta_p)
g_p += -w
else:
c_total += w * (0.5 / eta_p) * delta * delta
g_p += -w * (1.0 / eta_p) * delta
delta = c_p - p_l
c_total += w * delta * delta
g_p += 2.0 * w * delta
elif c_p > p_u:
delta = c_p - p_u
if (delta) > eta_p or eta_p == 0.0:
c_total += w * (delta - 0.5 * eta_p)
g_p += w
else:
c_total += w * (0.5 / eta_p) * delta * delta
g_p += w * (1.0 / eta_p) * delta
c_total += w * delta * delta
g_p += 2.0 * w * delta
out_cost[b_addrs] = c_total
@@ -811,73 +825,43 @@ def forward_bound_warp(
g_p = 2.0 * n_w * error
# bound cost:
delta = 0.0
if c_p < p_l:
delta = p_l - c_p
if (delta) > eta_p or eta_p == 0.0:
c_total += w * (delta - 0.5 * eta_p)
g_p += -w
else:
c_total += w * (0.5 / eta_p) * delta * delta
g_p += -w * (1.0 / eta_p) * delta
delta = c_p - p_l
elif c_p > p_u:
delta = c_p - p_u
if (delta) > eta_p or eta_p == 0.0:
c_total += w * (delta - 0.5 * eta_p)
g_p += w
else:
c_total += w * (0.5 / eta_p) * delta * delta
g_p += w * (1.0 / eta_p) * delta
c_total += w * delta * delta
g_p += 2.0 * w * delta
# bound cost:
delta = 0.0
if c_v < v_l:
delta = v_l - c_v
if (delta) > eta_v or eta_v == 0.0:
c_total += b_wv * (delta - 0.5 * eta_v)
g_v = -b_wv
else:
c_total += b_wv * (0.5 / eta_v) * delta * delta
g_v = -b_wv * (1.0 / eta_v) * delta
delta = c_v - v_l
elif c_v > v_u:
delta = c_v - v_u
if (delta) > eta_v or eta_v == 0.0:
c_total += b_wv * (delta - 0.5 * eta_v)
g_v = b_wv
else:
c_total += b_wv * (0.5 / eta_v) * delta * delta
g_v = b_wv * (1.0 / eta_v) * delta
c_total += b_wv * delta * delta
g_v = 2.0 * b_wv * delta
delta = 0.0
if c_a < a_l:
delta = a_l - c_a
if (delta) > eta_a or eta_a == 0.0:
c_total += b_wa * (delta - 0.5 * eta_a)
g_a = -b_wa
else:
c_total += b_wa * (0.5 / eta_a) * delta * delta
g_a = -b_wa * (1.0 / eta_a) * delta
delta = c_a - a_l
elif c_a > a_u:
delta = c_a - a_u
if (delta) > eta_a or eta_a == 0.0:
c_total += b_wa * (delta - 0.5 * eta_a)
g_a = b_wa
else:
c_total += b_wa * (0.5 / eta_a) * delta * delta
g_a = b_wa * (1.0 / eta_a) * delta
c_total += b_wa * delta * delta
g_a = b_wa * 2.0 * delta
delta = 0.0
if c_j < j_l:
delta = j_l - c_j
if (delta) > eta_j or eta_j == 0.0:
c_total += b_wj * (delta - 0.5 * eta_j)
g_j = -b_wj
else:
c_total += b_wj * (0.5 / eta_j) * delta * delta
g_j = -b_wj * (1.0 / eta_j) * delta
delta = c_j - j_l
elif c_j > j_u:
delta = c_j - j_u
if (delta) > eta_j or eta_j == 0.0:
c_total += b_wj * (delta - 0.5 * eta_j)
g_j = b_wj
else:
c_total += b_wj * (0.5 / eta_j) * delta * delta
g_j = b_wj * (1.0 / eta_j) * delta
c_total += b_wj * delta * delta
g_j = b_wj * delta * 2.0
out_cost[b_addrs] = c_total
@@ -1031,75 +1015,45 @@ def forward_bound_smooth_warp(
g_p = 2.0 * n_w * error
# bound cost:
# bound cost:
delta = 0.0
if c_p < p_l:
delta = p_l - c_p
if (delta) > eta_p or eta_p == 0.0:
c_total += w * (delta - 0.5 * eta_p)
g_p += -w
else:
c_total += w * (0.5 / eta_p) * delta * delta
g_p += -w * (1.0 / eta_p) * delta
delta = c_p - p_l
elif c_p > p_u:
delta = c_p - p_u
if (delta) > eta_p or eta_p == 0.0:
c_total += w * (delta - 0.5 * eta_p)
g_p += w
else:
c_total += w * (0.5 / eta_p) * delta * delta
g_p += w * (1.0 / eta_p) * delta
c_total += w * delta * delta
g_p += 2.0 * w * delta
# bound cost:
delta = 0.0
if c_v < v_l:
delta = v_l - c_v
if (delta) > eta_v or eta_v == 0.0:
c_total += b_wv * (delta - 0.5 * eta_v)
g_v = -b_wv
else:
c_total += b_wv * (0.5 / eta_v) * delta * delta
g_v = -b_wv * (1.0 / eta_v) * delta
delta = c_v - v_l
elif c_v > v_u:
delta = c_v - v_u
if (delta) > eta_v or eta_v == 0.0:
c_total += b_wv * (delta - 0.5 * eta_v)
g_v = b_wv
else:
c_total += b_wv * (0.5 / eta_v) * delta * delta
g_v = b_wv * (1.0 / eta_v) * delta
c_total += b_wv * delta * delta
g_v = 2.0 * b_wv * delta
delta = 0.0
if c_a < a_l:
delta = a_l - c_a
if (delta) > eta_a or eta_a == 0.0:
c_total += b_wa * (delta - 0.5 * eta_a)
g_a = -b_wa
else:
c_total += b_wa * (0.5 / eta_a) * delta * delta
g_a = -b_wa * (1.0 / eta_a) * delta
delta = c_a - a_l
elif c_a > a_u:
delta = c_a - a_u
if (delta) > eta_a or eta_a == 0.0:
c_total += b_wa * (delta - 0.5 * eta_a)
g_a = b_wa
else:
c_total += b_wa * (0.5 / eta_a) * delta * delta
g_a = b_wa * (1.0 / eta_a) * delta
c_total += b_wa * delta * delta
g_a = b_wa * 2.0 * delta
delta = 0.0
if c_j < j_l:
delta = j_l - c_j
if (delta) > eta_j or eta_j == 0.0:
c_total += b_wj * (delta - 0.5 * eta_j)
g_j = -b_wj
else:
c_total += b_wj * (0.5 / eta_j) * delta * delta
g_j = -b_wj * (1.0 / eta_j) * delta
delta = c_j - j_l
elif c_j > j_u:
delta = c_j - j_u
if (delta) > eta_j or eta_j == 0.0:
c_total += b_wj * (delta - 0.5 * eta_j)
g_j = b_wj
else:
c_total += b_wj * (0.5 / eta_j) * delta * delta
g_j = b_wj * (1.0 / eta_j) * delta
# g_v = -1.0 * g_v
# g_a = -1.0 * g_a
# g_j = - 1.0
c_total += b_wj * delta * delta
g_j = b_wj * delta * 2.0
# do l2 regularization for velocity:
if r_wv < 1.0:
s_v = w_v * r_wv * c_v * c_v

View File

@@ -128,12 +128,12 @@ def forward_l2_warp(
target_p = target[target_id]
error = c_p - target_p
if r_w >= 1.0 and w > 100.0:
c_total = w * wp.log2(wp.cosh(50.0 * error))
g_p = w * 50.0 * wp.sinh(50.0 * error) / (wp.cosh(50.0 * error))
else:
c_total = w * error * error
g_p = 2.0 * w * error
# if r_w >= 1.0 and w > 100.0:
# c_total = w * wp.log2(wp.cosh(10.0 * error))
# g_p = w * 10.0 * wp.sinh(10.0 * error) / (wp.cosh(10.0 * error))
# else:
c_total = w * error * error
g_p = 2.0 * w * error
out_cost[b_addrs] = c_total
@@ -159,8 +159,7 @@ class L2DistFunction(torch.autograd.Function):
):
wp_device = wp.device_from_torch(pos.device)
b, h, dof = pos.shape
# print(target)
requires_grad = pos.requires_grad
wp.launch(
kernel=forward_l2_warp,
dim=b * h * dof,
@@ -173,7 +172,7 @@ class L2DistFunction(torch.autograd.Function):
wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
wp.from_torch(out_cost_v.view(-1), dtype=wp.float32),
wp.from_torch(out_gp.view(-1), dtype=wp.float32),
pos.requires_grad,
requires_grad,
b,
h,
dof,
@@ -181,11 +180,8 @@ class L2DistFunction(torch.autograd.Function):
device=wp_device,
stream=wp.stream_from_torch(pos.device),
)
# cost = torch.linalg.norm(out_cost_v, dim=-1)
# if pos.requires_grad:
# out_gp = out_gp * torch.nan_to_num( 1.0/cost.unsqueeze(-1), 0.0)
cost = torch.sum(out_cost_v, dim=-1)
cost = torch.sum(out_cost_v, dim=-1)
ctx.save_for_backward(out_gp)
return cost
@@ -277,7 +273,11 @@ class DistCost(CostBase, DistCostConfig):
self._run_weight_vec[:, :-1] *= self.run_weight
cost = self._run_weight_vec * dist
if RETURN_GOAL_DIST:
return cost, dist / self.weight
dist_scale = torch.nan_to_num(
1.0 / torch.sqrt((self.weight * self._run_weight_vec)), 0.0
)
return cost, dist * dist_scale
return cost
def forward_target_idx(self, goal_vec, current_vec, goal_idx, RETURN_GOAL_DIST=False):
@@ -292,7 +292,6 @@ class DistCost(CostBase, DistCostConfig):
self._run_weight_vec[:, :-1] *= self.run_weight
else:
raise NotImplementedError("terminal flag needs to be set to true")
if self.dist_type == DistType.L2:
# print(goal_idx.shape, goal_vec.shape)
cost = L2DistFunction.apply(
@@ -306,11 +305,12 @@ class DistCost(CostBase, DistCostConfig):
self._out_cv_buffer,
self._out_g_buffer,
)
# cost = torch.linalg.norm(cost, dim=-1)
else:
raise NotImplementedError()
# print(cost.shape, cost[:,-1])
if RETURN_GOAL_DIST:
return cost, (cost / torch.sqrt((self.weight * self._run_weight_vec)))
dist_scale = torch.nan_to_num(
1.0 / torch.sqrt((self.weight * self._run_weight_vec)), 0.0
)
return cost, cost * dist_scale
return cost

View File

@@ -105,9 +105,9 @@ class PoseCostMetric:
@classmethod
def create_grasp_approach_metric(
cls,
offset_position: float = 0.5,
offset_position: float = 0.1,
linear_axis: int = 2,
tstep_fraction: float = 0.6,
tstep_fraction: float = 0.8,
tensor_args: TensorDeviceType = TensorDeviceType(),
) -> PoseCostMetric:
"""Enables moving to a pregrasp and then locked orientation movement to final grasp.
@@ -203,7 +203,6 @@ class PoseCost(CostBase, PoseCostConfig):
self.offset_waypoint[:3].copy_(offset_rotation)
self.offset_tstep_fraction[:] = offset_tstep_fraction
if self._horizon <= 0:
print(self.weight)
log_error(
"Updating offset waypoint is only possible after initializing motion gen"
+ " run motion_gen.warmup() before adding offset_waypoint"