Significantly improved convergence for mesh and cuboid, new ESDF collision.
This commit is contained in:
@@ -40,7 +40,7 @@ from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import CSpaceConfig, RobotConfig
|
||||
from curobo.types.state import JointState
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
from curobo.util.tensor_util import cat_sum
|
||||
from curobo.util.tensor_util import cat_sum, cat_sum_horizon
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -104,10 +104,10 @@ class ArmCostConfig:
|
||||
|
||||
@dataclass
|
||||
class ArmBaseConfig(RolloutConfig):
|
||||
model_cfg: KinematicModelConfig
|
||||
cost_cfg: ArmCostConfig
|
||||
constraint_cfg: ArmCostConfig
|
||||
convergence_cfg: ArmCostConfig
|
||||
model_cfg: Optional[KinematicModelConfig] = None
|
||||
cost_cfg: Optional[ArmCostConfig] = None
|
||||
constraint_cfg: Optional[ArmCostConfig] = None
|
||||
convergence_cfg: Optional[ArmCostConfig] = None
|
||||
world_coll_checker: Optional[WorldCollision] = None
|
||||
|
||||
@staticmethod
|
||||
@@ -322,7 +322,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
self.null_convergence = DistCost(self.convergence_cfg.null_space_cfg)
|
||||
|
||||
# set start state:
|
||||
start_state = torch.randn((1, self.dynamics_model.d_state), **vars(self.tensor_args))
|
||||
start_state = torch.randn(
|
||||
(1, self.dynamics_model.d_state), **(self.tensor_args.as_torch_dict())
|
||||
)
|
||||
self._start_state = JointState(
|
||||
position=start_state[:, : self.dynamics_model.d_dof],
|
||||
velocity=start_state[:, : self.dynamics_model.d_dof],
|
||||
@@ -366,9 +368,11 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
)
|
||||
cost_list.append(coll_cost)
|
||||
if return_list:
|
||||
|
||||
return cost_list
|
||||
cost = cat_sum(cost_list)
|
||||
if self.sum_horizon:
|
||||
cost = cat_sum_horizon(cost_list)
|
||||
else:
|
||||
cost = cat_sum(cost_list)
|
||||
return cost
|
||||
|
||||
def constraint_fn(
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#
|
||||
# Standard Library
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
@@ -29,8 +29,9 @@ from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import RobotConfig
|
||||
from curobo.types.tensor import T_BValue_float, T_BValue_int
|
||||
from curobo.util.helpers import list_idx_if_not_none
|
||||
from curobo.util.logger import log_error, log_info
|
||||
from curobo.util.tensor_util import cat_max, cat_sum
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
from curobo.util.tensor_util import cat_max
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# Local Folder
|
||||
from .arm_base import ArmBase, ArmBaseConfig, ArmCostConfig
|
||||
@@ -145,7 +146,7 @@ class ArmReacherConfig(ArmBaseConfig):
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def _compute_g_dist_jit(rot_err_norm, goal_dist):
|
||||
# goal_cost = goal_cost.view(cost.shape)
|
||||
# rot_err_norm = rot_err_norm.view(cost.shape)
|
||||
@@ -319,7 +320,12 @@ class ArmReacher(ArmBase, ArmReacherConfig):
|
||||
g_dist,
|
||||
)
|
||||
cost_list.append(z_vel)
|
||||
cost = cat_sum(cost_list)
|
||||
with profiler.record_function("cat_sum"):
|
||||
if self.sum_horizon:
|
||||
cost = cat_sum_horizon_reacher(cost_list)
|
||||
else:
|
||||
cost = cat_sum_reacher(cost_list)
|
||||
|
||||
return cost
|
||||
|
||||
def convergence_fn(
|
||||
@@ -466,3 +472,15 @@ class ArmReacher(ArmBase, ArmReacherConfig):
|
||||
)
|
||||
for x in pose_costs
|
||||
]
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def cat_sum_reacher(tensor_list: List[torch.Tensor]):
|
||||
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=0)
|
||||
return cat_tensor
|
||||
|
||||
|
||||
@get_torch_jit_decorator()
|
||||
def cat_sum_horizon_reacher(tensor_list: List[torch.Tensor]):
|
||||
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=(0, -1))
|
||||
return cat_tensor
|
||||
|
||||
@@ -21,6 +21,7 @@ import warp as wp
|
||||
from curobo.cuda_robot_model.types import JointLimits
|
||||
from curobo.types.robot import JointState
|
||||
from curobo.types.tensor import T_DOF
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.warp import init_warp
|
||||
|
||||
# Local Folder
|
||||
@@ -267,7 +268,7 @@ class BoundCost(CostBase, BoundCostConfig):
|
||||
return super().update_dt(dt)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
|
||||
# c = weight * torch.sum(torch.nn.functional.relu(torch.max(lower_bounds - p, p - upper_bounds)), dim=-1)
|
||||
|
||||
@@ -281,7 +282,7 @@ def forward_bound_cost(p, lower_bounds, upper_bounds, weight):
|
||||
return c
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def forward_all_bound_cost(
|
||||
p,
|
||||
v,
|
||||
|
||||
@@ -18,6 +18,7 @@ import torch
|
||||
import warp as wp
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.warp import init_warp
|
||||
|
||||
# Local Folder
|
||||
@@ -41,32 +42,32 @@ class DistCostConfig(CostConfig):
|
||||
return super().__post_init__()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def L2_DistCost_jit(vec_weight, disp_vec):
|
||||
return torch.norm(vec_weight * disp_vec, p=2, dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_SQL2_DistCost_jit(vec_weight, disp_vec):
|
||||
return torch.sum(torch.square(vec_weight * disp_vec), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_L1_DistCost_jit(vec_weight, disp_vec):
|
||||
return torch.sum(torch.abs(vec_weight * disp_vec), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def L2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
|
||||
return torch.norm(weight * vec_weight * (g_vec - c_vec), p=2, dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_SQL2_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
|
||||
return torch.sum(torch.square(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def fwd_L1_DistCost_target_jit(vec_weight, g_vec, c_vec, weight):
|
||||
return torch.sum(torch.abs(weight * vec_weight * (g_vec - c_vec)), dim=-1, keepdim=False)
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ import torch
|
||||
from curobo.geom.sdf.world import CollisionQueryBuffer, WorldCollision
|
||||
from curobo.rollout.cost.cost_base import CostBase, CostConfig
|
||||
from curobo.rollout.dynamics_model.integration_utils import interpolate_kernel, sum_matrix
|
||||
from curobo.util.logger import log_info
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -48,7 +50,11 @@ class PrimitiveCollisionCostConfig(CostConfig):
|
||||
#: post optimization interpolation to not hit any obstacles.
|
||||
activation_distance: Union[torch.Tensor, float] = 0.0
|
||||
|
||||
#: Setting this flag to true will sum the distance colliding obstacles.
|
||||
sum_collisions: bool = True
|
||||
|
||||
#: Setting this flag to true will sum the distance across spheres of the robot.
|
||||
#: Setting to False will only take the max distance
|
||||
sum_distance: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -103,6 +109,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
self._collision_query_buffer.update_buffer_shape(
|
||||
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
|
||||
)
|
||||
if not self.sum_distance:
|
||||
log_info("sum_distance=False will be slower than sum_distance=True")
|
||||
self.return_loss = True
|
||||
dist = self.sweep_check_fn(
|
||||
robot_spheres_in,
|
||||
self._collision_query_buffer,
|
||||
@@ -115,9 +124,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
return_loss=self.return_loss,
|
||||
)
|
||||
if self.classify:
|
||||
cost = weight_collision(dist, self.weight, self.sum_distance)
|
||||
cost = weight_collision(dist, self.sum_distance)
|
||||
else:
|
||||
cost = weight_distance(dist, self.weight, self.sum_distance)
|
||||
cost = weight_distance(dist, self.sum_distance)
|
||||
return cost
|
||||
|
||||
def sweep_fn(self, robot_spheres_in, env_query_idx: Optional[torch.Tensor] = None):
|
||||
@@ -140,6 +149,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
self._collision_query_buffer.update_buffer_shape(
|
||||
sampled_spheres.shape, self.tensor_args, self.world_coll_checker.collision_types
|
||||
)
|
||||
if not self.sum_distance:
|
||||
log_info("sum_distance=False will be slower than sum_distance=True")
|
||||
self.return_loss = True
|
||||
dist = self.coll_check_fn(
|
||||
sampled_spheres.contiguous(),
|
||||
self._collision_query_buffer,
|
||||
@@ -151,9 +163,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
dist = dist.view(batch_size, new_horizon, n_spheres)
|
||||
|
||||
if self.classify:
|
||||
cost = weight_sweep_collision(self.int_sum_mat, dist, self.weight, self.sum_distance)
|
||||
cost = weight_sweep_collision(self.int_sum_mat, dist, self.sum_distance)
|
||||
else:
|
||||
cost = weight_sweep_distance(self.int_sum_mat, dist, self.weight, self.sum_distance)
|
||||
cost = weight_sweep_distance(self.int_sum_mat, dist, self.sum_distance)
|
||||
|
||||
return cost
|
||||
|
||||
@@ -161,6 +173,9 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
self._collision_query_buffer.update_buffer_shape(
|
||||
robot_spheres_in.shape, self.tensor_args, self.world_coll_checker.collision_types
|
||||
)
|
||||
if not self.sum_distance:
|
||||
log_info("sum_distance=False will be slower than sum_distance=True")
|
||||
self.return_loss = True
|
||||
dist = self.coll_check_fn(
|
||||
robot_spheres_in,
|
||||
self._collision_query_buffer,
|
||||
@@ -168,12 +183,13 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
env_query_idx=env_query_idx,
|
||||
activation_distance=self.activation_distance,
|
||||
return_loss=self.return_loss,
|
||||
sum_collisions=self.sum_collisions,
|
||||
)
|
||||
|
||||
if self.classify:
|
||||
cost = weight_collision(dist, self.weight, self.sum_distance)
|
||||
cost = weight_collision(dist, self.sum_distance)
|
||||
else:
|
||||
cost = weight_distance(dist, self.weight, self.sum_distance)
|
||||
cost = weight_distance(dist, self.sum_distance)
|
||||
return cost
|
||||
|
||||
def update_dt(self, dt: Union[float, torch.Tensor]):
|
||||
@@ -184,31 +200,43 @@ class PrimitiveCollisionCost(CostBase, PrimitiveCollisionCostConfig):
|
||||
return self._collision_query_buffer.get_gradient_buffer()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weight_sweep_distance(int_mat, dist, weight, sum_cost: bool):
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
@get_torch_jit_decorator()
|
||||
def weight_sweep_distance(int_mat, dist, sum_cost: bool):
|
||||
if sum_cost:
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
else:
|
||||
dist = torch.max(dist, dim=-1)[0]
|
||||
dist = dist @ int_mat
|
||||
return dist
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weight_sweep_collision(int_mat, dist, weight, sum_cost: bool):
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
@get_torch_jit_decorator()
|
||||
def weight_sweep_collision(int_mat, dist, sum_cost: bool):
|
||||
if sum_cost:
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
else:
|
||||
dist = torch.max(dist, dim=-1)[0]
|
||||
|
||||
dist = torch.where(dist > 0, dist + 1.0, dist)
|
||||
dist = dist @ int_mat
|
||||
return dist
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weight_distance(dist, weight, sum_cost: bool):
|
||||
@get_torch_jit_decorator()
|
||||
def weight_distance(dist, sum_cost: bool):
|
||||
if sum_cost:
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
else:
|
||||
dist = torch.max(dist, dim=-1)[0]
|
||||
return dist
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weight_collision(dist, weight, sum_cost: bool):
|
||||
@get_torch_jit_decorator()
|
||||
def weight_collision(dist, sum_cost: bool):
|
||||
if sum_cost:
|
||||
dist = torch.sum(dist, dim=-1)
|
||||
else:
|
||||
dist = torch.max(dist, dim=-1)[0]
|
||||
|
||||
dist = torch.where(dist > 0, dist + 1.0, dist)
|
||||
return dist
|
||||
|
||||
@@ -17,6 +17,7 @@ import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.rollout.dynamics_model.kinematic_model import TimeTrajConfig
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# Local Folder
|
||||
from .cost_base import CostBase, CostConfig
|
||||
@@ -72,7 +73,7 @@ class StopCost(CostBase, StopCostConfig):
|
||||
return cost
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def velocity_cost(vels, weight, max_vel):
|
||||
vel_abs = torch.abs(vels)
|
||||
vel_abs = torch.nn.functional.relu(vel_abs - max_vel[: vels.shape[1]])
|
||||
|
||||
@@ -13,11 +13,14 @@
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# Local Folder
|
||||
from .cost_base import CostBase, CostConfig
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def st_cost(ee_pos_batch, vec_weight, weight):
|
||||
ee_plus_one = torch.roll(ee_pos_batch, 1, dims=1)
|
||||
|
||||
|
||||
@@ -11,11 +11,14 @@
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
# CuRobo
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# Local Folder
|
||||
from .cost_base import CostBase
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
# return weight * torch.square(torch.linalg.norm(cost, dim=-1, ord=1))
|
||||
# return weight * torch.sum(torch.square(cost), dim=-1)
|
||||
@@ -24,7 +27,7 @@ def squared_sum(cost: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sum(torch.square(cost) * weight, dim=-1)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def run_squared_sum(
|
||||
cost: torch.Tensor, weight: torch.Tensor, run_weight: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
@@ -35,13 +38,13 @@ def run_squared_sum(
|
||||
# return torch.sum(torch.square(cost), dim=-1) * weight * run_weight
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def backward_squared_sum(cost_vec, w):
|
||||
return 2.0 * w * cost_vec # * g_out.unsqueeze(-1)
|
||||
# return w * g_out.unsqueeze(-1)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def backward_run_squared_sum(cost_vec, w, r_w):
|
||||
return 2.0 * w * r_w.unsqueeze(-1) * cost_vec # * g_out.unsqueeze(-1)
|
||||
# return w * r_w.unsqueeze(-1) * cost_vec * g_out.unsqueeze(-1)
|
||||
|
||||
@@ -25,6 +25,7 @@ from curobo.curobolib.tensor_step import (
|
||||
)
|
||||
from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import JointState
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
def build_clique_matrix(horizon, dt, device="cpu", dtype=torch.float32):
|
||||
@@ -154,7 +155,7 @@ def build_start_state_mask(horizon, tensor_args: TensorDeviceType):
|
||||
return mask, n_mask
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def tensor_step_jerk(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
|
||||
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
|
||||
|
||||
@@ -176,7 +177,7 @@ def tensor_step_jerk(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_m
|
||||
return state_seq
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def euler_integrate(q_0, u, diag_dt, integrate_matrix):
|
||||
# q_new = q_0 + torch.matmul(integrate_matrix, torch.matmul(diag_dt, u))
|
||||
q_new = q_0 + torch.matmul(integrate_matrix, u)
|
||||
@@ -184,7 +185,7 @@ def euler_integrate(q_0, u, diag_dt, integrate_matrix):
|
||||
return q_new
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def tensor_step_acc(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix=None):
|
||||
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Optional[Tensor]) -> Tensor
|
||||
# This is batch,n_dof
|
||||
@@ -207,7 +208,7 @@ def tensor_step_acc(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_ma
|
||||
return state_seq
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def jit_tensor_step_pos_clique_contiguous(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
|
||||
state_position = (start_position.unsqueeze(1).transpose(1, 2) @ mask.transpose(0, 1)) + (
|
||||
pos_act.transpose(1, 2) @ n_mask.transpose(0, 1)
|
||||
@@ -222,7 +223,7 @@ def jit_tensor_step_pos_clique_contiguous(pos_act, start_position, mask, n_mask,
|
||||
return state_position, state_vel, state_acc, state_jerk
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def jit_tensor_step_pos_clique(pos_act, start_position, mask, n_mask, fd_1, fd_2, fd_3):
|
||||
state_position = mask @ start_position.unsqueeze(1) + n_mask @ pos_act
|
||||
state_vel = fd_1 @ state_position
|
||||
@@ -231,7 +232,7 @@ def jit_tensor_step_pos_clique(pos_act, start_position, mask, n_mask, fd_1, fd_2
|
||||
return state_position, state_vel, state_acc, state_jerk
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def jit_backward_pos_clique(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
|
||||
p_grad = (
|
||||
grad_p
|
||||
@@ -247,7 +248,7 @@ def jit_backward_pos_clique(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2,
|
||||
return u_grad
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def jit_backward_pos_clique_contiguous(grad_p, grad_v, grad_a, grad_j, n_mask, fd_1, fd_2, fd_3):
|
||||
p_grad = grad_p + (
|
||||
grad_j.transpose(-1, -2) @ fd_3
|
||||
@@ -532,7 +533,7 @@ class CliqueTensorStepIdxCentralDifferenceKernel(torch.autograd.Function):
|
||||
start_position,
|
||||
start_velocity,
|
||||
start_acceleration,
|
||||
start_idx,
|
||||
start_idx.contiguous(),
|
||||
traj_dt,
|
||||
out_position.shape[0],
|
||||
out_position.shape[1],
|
||||
@@ -750,7 +751,7 @@ class AccelerationTensorStepIdxKernel(torch.autograd.Function):
|
||||
return u_grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def tensor_step_pos_clique(
|
||||
state: JointState,
|
||||
act: torch.Tensor,
|
||||
@@ -786,7 +787,7 @@ def step_acc_semi_euler(state, act, diag_dt, n_dofs, integrate_matrix):
|
||||
return state_seq
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def tensor_step_acc_semi_euler(
|
||||
state, act, state_seq, diag_dt, integrate_matrix, integrate_matrix_pos
|
||||
):
|
||||
@@ -806,7 +807,7 @@ def tensor_step_acc_semi_euler(
|
||||
return state_seq
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def tensor_step_vel(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_matrix):
|
||||
# type: (Tensor, Tensor, Tensor, Tensor, int, Tensor, Tensor) -> Tensor
|
||||
|
||||
@@ -830,7 +831,7 @@ def tensor_step_vel(state, act, state_seq, dt_h, n_dofs, integrate_matrix, fd_ma
|
||||
return state_seq
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def tensor_step_pos(state, act, state_seq, fd_matrix):
|
||||
# This is batch,n_dof
|
||||
state_seq.position[:, 0, :] = state.position
|
||||
@@ -850,7 +851,7 @@ def tensor_step_pos(state, act, state_seq, fd_matrix):
|
||||
return state_seq
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
# @get_torch_jit_decorator()
|
||||
def tensor_step_pos_ik(act, state_seq):
|
||||
state_seq.position = act
|
||||
return state_seq
|
||||
@@ -869,7 +870,7 @@ def tensor_linspace(start_tensor, end_tensor, steps=10):
|
||||
|
||||
|
||||
def sum_matrix(h, int_steps, tensor_args):
|
||||
sum_mat = torch.zeros(((h - 1) * int_steps, h), **vars(tensor_args))
|
||||
sum_mat = torch.zeros(((h - 1) * int_steps, h), **(tensor_args.as_torch_dict()))
|
||||
for i in range(h - 1):
|
||||
sum_mat[i * int_steps : i * int_steps + int_steps, i] = 1.0
|
||||
# hack:
|
||||
|
||||
@@ -19,6 +19,7 @@ import torch
|
||||
# CuRobo
|
||||
from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import JointState
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
# Local Folder
|
||||
from .integration_utils import (
|
||||
@@ -544,7 +545,7 @@ class TensorStepPositionCliqueKernel(TensorStepBase):
|
||||
return new_signal
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(force_jit=True)
|
||||
def filter_signal_jit(signal, kernel):
|
||||
b, h, dof = signal.shape
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from curobo.util.helpers import list_idx_if_not_none
|
||||
from curobo.util.logger import log_info
|
||||
from curobo.util.sample_lib import HaltonGenerator
|
||||
from curobo.util.tensor_util import copy_tensor
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -298,9 +299,9 @@ class Goal(Sequence):
|
||||
if self.goal_pose is not None:
|
||||
self.goal_pose = self.goal_pose.to(tensor_args)
|
||||
if self.goal_state is not None:
|
||||
self.goal_state = self.goal_state.to(**vars(tensor_args))
|
||||
self.goal_state = self.goal_state.to(**(tensor_args.as_torch_dict()))
|
||||
if self.current_state is not None:
|
||||
self.current_state = self.current_state.to(**vars(tensor_args))
|
||||
self.current_state = self.current_state.to(**(tensor_args.as_torch_dict()))
|
||||
return self
|
||||
|
||||
def copy_(self, goal: Goal, update_idx_buffers: bool = True):
|
||||
@@ -350,6 +351,7 @@ class Goal(Sequence):
|
||||
if ref_buffer is not None:
|
||||
ref_buffer = ref_buffer.copy_(buffer)
|
||||
else:
|
||||
log_info("breaking reference")
|
||||
ref_buffer = buffer.clone()
|
||||
return ref_buffer
|
||||
|
||||
@@ -414,6 +416,7 @@ class Goal(Sequence):
|
||||
@dataclass
|
||||
class RolloutConfig:
|
||||
tensor_args: TensorDeviceType
|
||||
sum_horizon: bool = False
|
||||
|
||||
|
||||
class RolloutBase:
|
||||
@@ -578,7 +581,7 @@ class RolloutBase:
|
||||
return q_js
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def tensor_repeat_seeds(tensor, num_seeds: int):
|
||||
a = (
|
||||
tensor.view(tensor.shape[0], 1, tensor.shape[-1])
|
||||
|
||||
Reference in New Issue
Block a user