Significantly improved convergence for mesh and cuboid, new ESDF collision.
This commit is contained in:
@@ -21,6 +21,7 @@ import warp as wp
|
||||
from curobo.cuda_robot_model.types import JointLimits
|
||||
from curobo.types.robot import JointState
|
||||
from curobo.types.tensor import T_DOF
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.warp import init_warp
|
||||
|
||||
# Local Folder
|
||||
@@ -267,7 +268,7 @@ class BoundCost(CostBase, BoundCostConfig):
|
||||
return super().update_dt(dt)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
|
||||
# c = weight * torch.sum(torch.nn.functional.relu(torch.max(lower_bounds - p, p - upper_bounds)), dim=-1)
|
||||
|
||||
@@ -281,7 +282,7 @@ def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
|
||||
return c
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def forward_all_bound_cost(
|
||||
p,
|
||||
v,
|
||||
|
||||
@@ -18,6 +18,7 @@ import torch
|
||||
import warp as wp
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.warp import init_warp
|
||||
|
||||
# Local Folder
|
||||
@@ -41,32 +42,32 @@ class DistCostConfig(CostConfig):
|
||||
return super().__post_init__()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def L2_DistCost_jit(vec_weight, disp_vec):
|
||||
return torch.norm(vec_weight * disp_vec, p=2, dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_SQL2_DistCost_jit(vec_weight, disp_vec):
|
||||
return torch.sum(torch.square(vec_weight * disp_vec), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_L1_DistCost_jit(vec_weight, disp_vec):
|
||||
return torch.sum(torch.abs(vec_weight * disp_vec), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def L2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
|
||||
return torch.norm(weight * vec_weight * (g_vec - c_vec), p=2, dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_SQL2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
|
||||
return torch.sum(torch.square(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_L1_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
|
||||
return torch.sum(torch.abs(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ import torch
|
||||
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollision
|
||||
from curobo.rollout.cost.cost_base import CostBase, CostConfig
|
||||
from curobo.rollout.dynamics_model.integration_utils import interpolate_kernel, sum_matrix
|
||||
from curobo.util.logger import log_info
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -48,7 +50,11 @@ class PrimitiveCollisionCostConfig(CostConfig):
|
||||
#: post optimization interpolation to not hit any obstacles.
|
||||
activation_distance: Union[torch.Tensor, float] = 0.0
|
||||
|
||||
#: Setting this flag to true will sum the distance colliding obstacles.
|
||||
sum_collisions: bool = True
|
||||
|
||||
#: Setting this flag to true will sum the distance across spheres of the robot.
|
||||
#: Setting to False will only take the max distance
|
||||
sum_distance: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -103,6 +109,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
self._collision_query_buffer.update_buffer_shape(
|
||||
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
|
||||
)
|
||||
if not self.sum_distance:
|
||||
log_info("sum_distance=False will be slower than sum_distance=True")
|
||||
self.return_loss = True
|
||||
dist = self.sweep_check_fn(
|
||||
robot_spheres_in,
|
||||
self._collision_query_buffer,
|
||||
@@ -115,9 +124,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
return_loss=self.return_loss,
|
||||
)
|
||||
if self.classify:
|
||||
cost = weight_collision(dist, self.weight, self.sum_distance)
|
||||
cost = weight_collision(dist, self.sum_distance)
|
||||
else:
|
||||
cost = weight_distance(dist, self.weight, self.sum_distance)
|
||||
cost = weight_distance(dist, self.sum_distance)
|
||||
return cost
|
||||
|
||||
def sweep_fn(self, robot_spheres_in, env_query_idx: Optional[torch.Tensor] = None):
|
||||
@@ -140,6 +149,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
self._collision_query_buffer.update_buffer_shape(
|
||||
sampled_spheres.shape, self.tensor_args, self.world_coll_checker.collision_types
|
||||
)
|
||||
if not self.sum_distance:
|
||||
log_info("sum_distance=False will be slower than sum_distance=True")
|
||||
self.return_loss = True
|
||||
dist = self.coll_check_fn(
|
||||
sampled_spheres.contiguous(),
|
||||
self._collision_query_buffer,
|
||||
@@ -151,9 +163,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
dist = dist.view(batch_size, new_horizon, n_spheres)
|
||||
|
||||
if self.classify:
|
||||
cost = weight_sweep_collision(self.int_sum_mat, dist, self.weight, self.sum_distance)
|
||||
cost = weight_sweep_collision(self.int_sum_mat, dist, self.sum_distance)
|
||||
else:
|
||||
cost = weight_sweep_distance(self.int_sum_mat, dist, self.weight, self.sum_distance)
|
||||
cost = weight_sweep_distance(self.int_sum_mat, dist, self.sum_distance)
|
||||
|
||||
return cost
|
||||
|
||||
@@ -161,6 +173,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
self._collision_query_buffer.update_buffer_shape(
|
||||
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
|
||||
)
|
||||
if not self.sum_distance:
|
||||
log_info("sum_distance=False will be slower than sum_distance=True")
|
||||
self.return_loss = True
|
||||
dist = self.coll_check_fn(
|
||||
robot_spheres_in,
|
||||
self._collision_query_buffer,
|
||||
@@ -168,12 +183,13 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
env_query_idx=env_query_idx,
|
||||
activation_distance=self.activation_distance,
|
||||
return_loss=self.return_loss,
|
||||
sum_collisions=self.sum_collisions,
|
||||
)
|
||||
|
||||
if self.classify:
|
||||
cost = weight_collision(dist, self.weight, self.sum_distance)
|
||||
cost = weight_collision(dist, self.sum_distance)
|
||||
else:
|
||||
cost = weight_distance(dist, self.weight, self.sum_distance)
|
||||
cost = weight_distance(dist, self.sum_distance)
|
||||
return cost
|
||||
|
||||
def update_dt(self, dt: Union[float, torch.Tensor]):
|
||||
@@ -184,31 +200,43 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
return self._collision_query_buffer.get_gradient_buffer()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weight_sweep_distance(int_mat, dist, weight, sum_cost: bool):
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
@get_torch_jit_decorator()
|
||||
def weight_sweep_distance(int_mat, dist, sum_cost: bool):
|
||||
if sum_cost:
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
else:
|
||||
dist = torch.max(dist, dim=-1)[0]
|
||||
dist = dist @ int_mat
|
||||
return dist
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weight_sweep_collision(int_mat, dist, weight, sum_cost: bool):
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
@get_torch_jit_decorator()
|
||||
def weight_sweep_collision(int_mat, dist, sum_cost: bool):
|
||||
if sum_cost:
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
else:
|
||||
dist = torch.max(dist, dim=-1)[0]
|
||||
|
||||
dist = torch.where(dist > 0, dist + 1.0, dist)
|
||||
dist = dist @ int_mat
|
||||
return dist
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weight_distance(dist, weight, sum_cost: bool):
|
||||
@get_torch_jit_decorator()
|
||||
def weight_distance(dist, sum_cost: bool):
|
||||
if sum_cost:
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
else:
|
||||
dist = torch.max(dist, dim=-1)[0]
|
||||
return dist
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weight_collision(dist, weight, sum_cost: bool):
|
||||
@get_torch_jit_decorator()
|
||||
def weight_collision(dist, sum_cost: bool):
|
||||
if sum_cost:
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
else:
|
||||
dist = torch.max(dist, dim=-1)[0]
|
||||
|
||||
dist = torch.where(dist > 0, dist + 1.0, dist)
|
||||
return dist
|
||||
|
||||
@@ -17,6 +17,7 @@ import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.rollout.dynamics_model.kinematic_model import TimeTrajConfig
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# Local Folder
|
||||
from .cost_base import CostBase, CostConfig
|
||||
@@ -72,7 +73,7 @@ class StopCost(CostBase, StopCostConfig):
|
||||
return cost
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def velocity_cost(vels, weight, max_vel):
|
||||
vel_abs = torch.abs(vels)
|
||||
vel_abs = torch.nn.functional.relu(vel_abs - max_vel[: vels.shape[1]])
|
||||
|
||||
@@ -13,11 +13,14 @@
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# Local Folder
|
||||
from .cost_base import CostBase, CostConfig
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def st_cost(ee_pos_batch, vec_weight, weight):
|
||||
ee_plus_one = torch.roll(ee_pos_batch, 1, dims=1)
|
||||
|
||||
|
||||
@@ -11,11 +11,14 @@
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# Local Folder
|
||||
from .cost_base import CostBase
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
# return weight * torch.square(torch.linalg.norm(cost, dim=-1, ord=1))
|
||||
# return weight * torch.sum(torch.square(cost), dim=-1)
|
||||
@@ -24,7 +27,7 @@ def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sum(torch.square(cost) * weight, dim=-1)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def run_squared_sum(
|
||||
cost: torch.Tensor, weight: torch.Tensor, run_weight: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
@@ -35,13 +38,13 @@ def run_squared_sum(
|
||||
# return torch.sum(torch.square(cost), dim=-1) * weight * run_weight
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def backward_squared_sum(cost_vec, w):
|
||||
return 2.0 * w * cost_vec # * g_out.unsqueeze(-1)
|
||||
# return w * g_out.unsqueeze(-1)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def backward_run_squared_sum(cost_vec, w, r_w):
|
||||
return 2.0 * w * r_w.unsqueeze(-1) * cost_vec # * g_out.unsqueeze(-1)
|
||||
# return w * r_w.unsqueeze(-1) * cost_vec * g_out.unsqueeze(-1)
|
||||
|
||||
Reference in New Issue
Block a user