Significantly improved convergence for mesh and cuboid, new ESDF collision.
This commit is contained in:
@@ -14,6 +14,9 @@ from typing import List
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
def check_tensor_shapes(new_tensor: torch.Tensor, mem_tensor: torch.Tensor):
|
||||
if not isinstance(mem_tensor, torch.Tensor):
|
||||
@@ -65,13 +68,19 @@ def clone_if_not_none(x):
|
||||
return None
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def cat_sum(tensor_list: List[torch.Tensor]):
|
||||
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=0)
|
||||
return cat_tensor
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def cat_sum_horizon(tensor_list: List[torch.Tensor]):
|
||||
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=(0, -1))
|
||||
return cat_tensor
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def cat_max(tensor_list: List[torch.Tensor]):
|
||||
cat_tensor = torch.max(torch.stack(tensor_list, dim=0), dim=0)[0]
|
||||
return cat_tensor
|
||||
@@ -85,7 +94,7 @@ def tensor_repeat_seeds(tensor, num_seeds):
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fd_tensor(p: torch.Tensor, dt: torch.Tensor):
|
||||
out = ((torch.roll(p, -1, -2) - p) * (1 / dt).unsqueeze(-1))[..., :-1, :]
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user