Improved precision, quality and js planner.
This commit is contained in:
@@ -128,12 +128,12 @@ def forward_l2_warp(
|
||||
target_p = target[target_id]
|
||||
error = c_p - target_p
|
||||
|
||||
if r_w >= 1.0 and w > 100.0:
|
||||
c_total = w * wp.log2(wp.cosh(50.0 * error))
|
||||
g_p = w * 50.0 * wp.sinh(50.0 * error) / (wp.cosh(50.0 * error))
|
||||
else:
|
||||
c_total = w * error * error
|
||||
g_p = 2.0 * w * error
|
||||
# if r_w >= 1.0 and w > 100.0:
|
||||
# c_total = w * wp.log2(wp.cosh(10.0 * error))
|
||||
# g_p = w * 10.0 * wp.sinh(10.0 * error) / (wp.cosh(10.0 * error))
|
||||
# else:
|
||||
c_total = w * error * error
|
||||
g_p = 2.0 * w * error
|
||||
|
||||
out_cost[b_addrs] = c_total
|
||||
|
||||
@@ -159,8 +159,7 @@ class L2DistFunction(torch.autograd.Function):
|
||||
):
|
||||
wp_device = wp.device_from_torch(pos.device)
|
||||
b, h, dof = pos.shape
|
||||
# print(target)
|
||||
|
||||
requires_grad = pos.requires_grad
|
||||
wp.launch(
|
||||
kernel=forward_l2_warp,
|
||||
dim=b * h * dof,
|
||||
@@ -173,7 +172,7 @@ class L2DistFunction(torch.autograd.Function):
|
||||
wp.from_torch(vec_weight.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(out_cost_v.view(-1), dtype=wp.float32),
|
||||
wp.from_torch(out_gp.view(-1), dtype=wp.float32),
|
||||
pos.requires_grad,
|
||||
requires_grad,
|
||||
b,
|
||||
h,
|
||||
dof,
|
||||
@@ -181,11 +180,8 @@ class L2DistFunction(torch.autograd.Function):
|
||||
device=wp_device,
|
||||
stream=wp.stream_from_torch(pos.device),
|
||||
)
|
||||
# cost = torch.linalg.norm(out_cost_v, dim=-1)
|
||||
# if pos.requires_grad:
|
||||
# out_gp = out_gp * torch.nan_to_num( 1.0/cost.unsqueeze(-1), 0.0)
|
||||
cost = torch.sum(out_cost_v, dim=-1)
|
||||
|
||||
cost = torch.sum(out_cost_v, dim=-1)
|
||||
ctx.save_for_backward(out_gp)
|
||||
return cost
|
||||
|
||||
@@ -277,7 +273,11 @@ class DistCost(CostBase, DistCostConfig):
|
||||
self._run_weight_vec[:, :-1] *= self.run_weight
|
||||
cost = self._run_weight_vec * dist
|
||||
if RETURN_GOAL_DIST:
|
||||
return cost, dist / self.weight
|
||||
dist_scale = torch.nan_to_num(
|
||||
1.0 / torch.sqrt((self.weight * self._run_weight_vec)), 0.0
|
||||
)
|
||||
|
||||
return cost, dist * dist_scale
|
||||
return cost
|
||||
|
||||
def forward_target_idx(self, goal_vec, current_vec, goal_idx, RETURN_GOAL_DIST=False):
|
||||
@@ -292,7 +292,6 @@ class DistCost(CostBase, DistCostConfig):
|
||||
self._run_weight_vec[:, :-1] *= self.run_weight
|
||||
else:
|
||||
raise NotImplementedError("terminal flag needs to be set to true")
|
||||
|
||||
if self.dist_type == DistType.L2:
|
||||
# print(goal_idx.shape, goal_vec.shape)
|
||||
cost = L2DistFunction.apply(
|
||||
@@ -306,11 +305,12 @@ class DistCost(CostBase, DistCostConfig):
|
||||
self._out_cv_buffer,
|
||||
self._out_g_buffer,
|
||||
)
|
||||
# cost = torch.linalg.norm(cost, dim=-1)
|
||||
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
# print(cost.shape, cost[:,-1])
|
||||
if RETURN_GOAL_DIST:
|
||||
return cost, (cost / torch.sqrt((self.weight * self._run_weight_vec)))
|
||||
dist_scale = torch.nan_to_num(
|
||||
1.0 / torch.sqrt((self.weight * self._run_weight_vec)), 0.0
|
||||
)
|
||||
return cost, cost * dist_scale
|
||||
return cost
|
||||
|
||||
Reference in New Issue
Block a user