update to 0.6.2
This commit is contained in:
@@ -340,9 +340,9 @@ class CliqueTensorStepKernel(torch.autograd.Function):
|
||||
grad_out_a,
|
||||
grad_out_j,
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
)
|
||||
return (
|
||||
u_grad,
|
||||
@@ -412,9 +412,9 @@ class CliqueTensorStepIdxKernel(torch.autograd.Function):
|
||||
grad_out_a,
|
||||
grad_out_j,
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
)
|
||||
return (
|
||||
u_grad,
|
||||
@@ -483,9 +483,9 @@ class CliqueTensorStepCentralDifferenceKernel(torch.autograd.Function):
|
||||
grad_out_a.contiguous(),
|
||||
grad_out_j.contiguous(),
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
0,
|
||||
)
|
||||
return (
|
||||
@@ -557,9 +557,9 @@ class CliqueTensorStepIdxCentralDifferenceKernel(torch.autograd.Function):
|
||||
grad_out_a.contiguous(),
|
||||
grad_out_j.contiguous(),
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
0,
|
||||
)
|
||||
return (
|
||||
@@ -626,9 +626,9 @@ class CliqueTensorStepCoalesceKernel(torch.autograd.Function):
|
||||
grad_out_a.transpose(-1, -2).contiguous(),
|
||||
grad_out_j.transpose(-1, -2).contiguous(),
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
)
|
||||
return (
|
||||
u_grad.transpose(-1, -2).contiguous(),
|
||||
@@ -890,16 +890,16 @@ def interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType):
|
||||
return mat
|
||||
|
||||
|
||||
def action_interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType):
|
||||
def action_interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType, offset: int = 4):
|
||||
mat = torch.zeros(
|
||||
((h - 1) * (int_steps), h), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
)
|
||||
delta = torch.arange(0, int_steps - 2, device=tensor_args.device, dtype=tensor_args.dtype) / (
|
||||
int_steps - 1.0 - 2
|
||||
)
|
||||
delta = torch.arange(
|
||||
0, int_steps - offset + 1, device=tensor_args.device, dtype=tensor_args.dtype
|
||||
) / (int_steps - offset)
|
||||
for i in range(h - 1):
|
||||
mat[i * int_steps : i * (int_steps) + int_steps - 1 - 2, i] = delta.flip(0)[1:]
|
||||
mat[i * int_steps : i * (int_steps) + int_steps - 1 - 2, i + 1] = delta[1:]
|
||||
mat[-3:, 1] = 1.0
|
||||
mat[i * int_steps : i * (int_steps) + int_steps - offset, i] = delta.flip(0)[1:]
|
||||
mat[i * int_steps : i * (int_steps) + int_steps - offset, i + 1] = delta[1:]
|
||||
mat[-offset:, 1] = 1.0
|
||||
|
||||
return mat
|
||||
|
||||
Reference in New Issue
Block a user