Improved precision, quality, and the joint-space (js) planner.

Author: Balakumar Sundaralingam
Date: 2024-04-11 13:19:01 -07:00
Parent: 774dcfd609
Commit: d6e600c88c
51 changed files with 2128 additions and 1054 deletions


@@ -128,12 +128,12 @@ def forward_l2_warp(
     target_p = target[target_id]
     error = c_p - target_p
 
-    if r_w >= 1.0 and w > 100.0:
-        c_total = w * wp.log2(wp.cosh(50.0 * error))
-        g_p = w * 50.0 * wp.sinh(50.0 * error) / (wp.cosh(50.0 * error))
-    else:
-        c_total = w * error * error
-        g_p = 2.0 * w * error
+    # if r_w >= 1.0 and w > 100.0:
+    #     c_total = w * wp.log2(wp.cosh(10.0 * error))
+    #     g_p = w * 10.0 * wp.sinh(10.0 * error) / (wp.cosh(10.0 * error))
+    # else:
+    c_total = w * error * error
+    g_p = 2.0 * w * error
 
     out_cost[b_addrs] = c_total
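The removed branch above switched heavily weighted terms to a scaled log-cosh penalty; after this commit the kernel always uses the plain quadratic w * error**2 with gradient 2 * w * error. As a standalone sketch (hypothetical helper names, not part of the commit), the two profiles compare as follows: log-cosh is quadratic near zero error but grows linearly far from it, which bounds the gradient magnitude at w * k / ln 2.

import math

def quadratic_cost(error: float, w: float):
    # Profile kept by this commit: weighted squared error and its gradient.
    return w * error * error, 2.0 * w * error

def logcosh_cost(error: float, w: float, k: float = 50.0):
    # Profile removed by this commit. Note d/de[w * log2(cosh(k*e))] is
    # w * k * tanh(k*e) / ln(2); the removed g_p used w * k * tanh(k*e).
    c = w * math.log2(math.cosh(k * error))
    g = w * k * math.tanh(k * error) / math.log(2.0)
    return c, g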
@@ -159,8 +159,7 @@ class L2DistFunction(torch.autograd.Function):
     ):
         wp_device = wp.device_from_torch(pos.device)
         b, h, dof = pos.shape
-        # print(target)
+        requires_grad = pos.requires_grad
         wp.launch(
             kernel=forward_l2_warp,
             dim=b * h * dof,
@@ -173,7 +172,7 @@ class L2DistFunction(torch.autograd.Function):
                 wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
                 wp.from_torch(out_cost_v.view(-1), dtype=wp.float32),
                 wp.from_torch(out_gp.view(-1), dtype=wp.float32),
-                pos.requires_grad,
+                requires_grad,
                 b,
                 h,
                 dof,
@@ -181,11 +180,8 @@ class L2DistFunction(torch.autograd.Function):
             device=wp_device,
             stream=wp.stream_from_torch(pos.device),
         )
-        # cost = torch.linalg.norm(out_cost_v, dim=-1)
-        # if pos.requires_grad:
-        #     out_gp = out_gp * torch.nan_to_num( 1.0/cost.unsqueeze(-1), 0.0)
-        cost = torch.sum(out_cost_v, dim=-1)
+        cost = torch.sum(out_cost_v, dim=-1)
         ctx.save_for_backward(out_gp)
         return cost
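The hunks above cache pos.requires_grad in a local before the launch, and the forward pass precomputes the gradient buffer out_gp so that backward reduces to a single multiply. A minimal sketch of that pattern (hypothetical class and shapes, not the library's API):

import torch

class L2DistSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, pos, target, weight):
        error = pos - target
        out_gp = 2.0 * weight * error              # gradient filled during forward
        cost = torch.sum(weight * error * error, dim=-1)
        ctx.save_for_backward(out_gp)              # same pattern as the hunk above
        return cost

    @staticmethod
    def backward(ctx, grad_out):
        (out_gp,) = ctx.saved_tensors
        # Chain rule: scale the precomputed per-element gradient by the
        # incoming cost gradient; target and weight get no gradient here.
        return grad_out.unsqueeze(-1) * out_gp, None, None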
@@ -277,7 +273,11 @@ class DistCost(CostBase, DistCostConfig):
                 self._run_weight_vec[:, :-1] *= self.run_weight
         cost = self._run_weight_vec * dist
         if RETURN_GOAL_DIST:
-            return cost, dist / self.weight
+            dist_scale = torch.nan_to_num(
+                1.0 / torch.sqrt((self.weight * self._run_weight_vec)), 0.0
+            )
+            return cost, dist * dist_scale
         return cost
 
     def forward_target_idx(self, goal_vec, current_vec, goal_idx, RETURN_GOAL_DIST=False):
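Both RETURN_GOAL_DIST branches (here and in forward_target_idx below) now derive the reported goal distance with a 1/sqrt(weight * run_weight_vec) rescaling instead of dividing by the raw weight. A standalone sketch with hypothetical values (not from the repo); note that torch.nan_to_num's second positional argument replaces only NaN, while +inf from a zero run weight is governed by the separate posinf argument:

import torch

weight = torch.tensor(5000.0)                     # hypothetical scalar weight
run_weight_vec = torch.tensor([[0.0, 0.0, 1.0]])  # zero run weight early in the horizon

raw = 1.0 / torch.sqrt(weight * run_weight_vec)   # [inf, inf, 0.0141]
print(torch.nan_to_num(raw, 0.0))                 # NaN -> 0; +inf -> float32 max
print(torch.nan_to_num(raw, 0.0, posinf=0.0))     # +inf -> 0 as well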
@@ -292,7 +292,6 @@ class DistCost(CostBase, DistCostConfig):
                 self._run_weight_vec[:, :-1] *= self.run_weight
         else:
             raise NotImplementedError("terminal flag needs to be set to true")
-
         if self.dist_type == DistType.L2:
             # print(goal_idx.shape, goal_vec.shape)
             cost = L2DistFunction.apply(
@@ -306,11 +305,12 @@ class DistCost(CostBase, DistCostConfig):
                 self._out_cv_buffer,
                 self._out_g_buffer,
             )
-            # cost = torch.linalg.norm(cost, dim=-1)
         else:
             raise NotImplementedError()
-        # print(cost.shape, cost[:,-1])
         if RETURN_GOAL_DIST:
-            return cost, (cost / torch.sqrt((self.weight * self._run_weight_vec)))
+            dist_scale = torch.nan_to_num(
+                1.0 / torch.sqrt((self.weight * self._run_weight_vec)), 0.0
+            )
+            return cost, cost * dist_scale
         return cost