Significantly improved convergence for mesh and cuboid, new ESDF collision.
This commit is contained in:
@@ -25,6 +25,7 @@ from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import JointState
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
from curobo.util.sample_lib import bspline
|
||||
from curobo.util.torch_utils import get_torch_jit_decorator
|
||||
from curobo.util.warp_interpolation import get_cuda_linear_interpolation
|
||||
|
||||
|
||||
@@ -114,7 +115,7 @@ def get_spline_interpolated_trajectory(raw_traj, des_horizon, degree=5):
|
||||
|
||||
for i in range(cpu_traj.shape[-1]):
|
||||
retimed_traj[:, i] = bspline(cpu_traj[:, i], n=des_horizon, degree=degree)
|
||||
retimed_traj = retimed_traj.to(**vars(tensor_args))
|
||||
retimed_traj = retimed_traj.to(**(tensor_args.as_torch_dict()))
|
||||
return retimed_traj
|
||||
|
||||
|
||||
@@ -385,7 +386,7 @@ def get_interpolated_trajectory(
|
||||
kind=kind,
|
||||
last_step=des_horizon,
|
||||
)
|
||||
retimed_traj = retimed_traj.to(**vars(tensor_args))
|
||||
retimed_traj = retimed_traj.to(**(tensor_args.as_torch_dict()))
|
||||
out_traj_state.position[b, :interpolation_steps, :] = retimed_traj
|
||||
out_traj_state.position[b, interpolation_steps:, :] = retimed_traj[
|
||||
interpolation_steps - 1 : interpolation_steps, :
|
||||
@@ -438,7 +439,39 @@ def linear_smooth(
|
||||
return y_new
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def calculate_dt_fixed(
|
||||
vel: torch.Tensor,
|
||||
acc: torch.Tensor,
|
||||
jerk: torch.Tensor,
|
||||
max_vel: torch.Tensor,
|
||||
max_acc: torch.Tensor,
|
||||
max_jerk: torch.Tensor,
|
||||
raw_dt: torch.Tensor,
|
||||
interpolation_dt: float,
|
||||
):
|
||||
# compute scaled dt:
|
||||
max_v_arr = torch.max(torch.abs(vel), dim=-2)[0] # output is batch, dof
|
||||
|
||||
max_acc_arr = torch.max(torch.abs(acc), dim=-2)[0]
|
||||
max_jerk_arr = torch.max(torch.abs(jerk), dim=-2)[0]
|
||||
|
||||
vel_scale_dt = (max_v_arr) / (max_vel.view(1, max_v_arr.shape[-1])) # batch,dof
|
||||
acc_scale_dt = max_acc_arr / (max_acc.view(1, max_acc_arr.shape[-1]))
|
||||
jerk_scale_dt = max_jerk_arr / (max_jerk.view(1, max_jerk_arr.shape[-1]))
|
||||
|
||||
dt_score_vel = raw_dt * torch.max(vel_scale_dt, dim=-1)[0] # batch, 1
|
||||
dt_score_acc = raw_dt * torch.sqrt((torch.max(acc_scale_dt, dim=-1)[0]))
|
||||
dt_score_jerk = raw_dt * torch.pow((torch.max(jerk_scale_dt, dim=-1)[0]), 1 / 3)
|
||||
dt_score = torch.maximum(dt_score_vel, dt_score_acc)
|
||||
dt_score = torch.maximum(dt_score, dt_score_jerk)
|
||||
dt_score = torch.clamp(dt_score, interpolation_dt, raw_dt)
|
||||
# NOTE: this dt score is not dt, rather a scaling to convert velocity, acc, jerk that was
|
||||
# computed with raw_dt to a new dt
|
||||
return dt_score
|
||||
|
||||
|
||||
@get_torch_jit_decorator(force_jit=True)
|
||||
def calculate_dt(
|
||||
vel: torch.Tensor,
|
||||
acc: torch.Tensor,
|
||||
@@ -470,7 +503,7 @@ def calculate_dt(
|
||||
return dt_score
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator(force_jit=True)
|
||||
def calculate_dt_no_clamp(
|
||||
vel: torch.Tensor,
|
||||
acc: torch.Tensor,
|
||||
@@ -497,7 +530,7 @@ def calculate_dt_no_clamp(
|
||||
return dt_score
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@get_torch_jit_decorator()
|
||||
def calculate_tsteps(
|
||||
vel: torch.Tensor,
|
||||
acc: torch.Tensor,
|
||||
@@ -506,13 +539,15 @@ def calculate_tsteps(
|
||||
max_vel: torch.Tensor,
|
||||
max_acc: torch.Tensor,
|
||||
max_jerk: torch.Tensor,
|
||||
raw_dt: float,
|
||||
raw_dt: torch.Tensor,
|
||||
min_dt: float,
|
||||
horizon: int,
|
||||
optimize_dt: bool = True,
|
||||
):
|
||||
# compute scaled dt:
|
||||
opt_dt = calculate_dt(vel, acc, jerk, max_vel, max_acc, max_jerk, raw_dt, interpolation_dt)
|
||||
opt_dt = calculate_dt_fixed(
|
||||
vel, acc, jerk, max_vel, max_acc, max_jerk, raw_dt, interpolation_dt
|
||||
)
|
||||
if not optimize_dt:
|
||||
opt_dt[:] = raw_dt
|
||||
traj_steps = (torch.ceil((horizon - 1) * ((opt_dt) / interpolation_dt))).to(dtype=torch.int32)
|
||||
|
||||
Reference in New Issue
Block a user