update to 0.6.2

This commit is contained in:
Balakumar Sundaralingam
2023-12-15 02:01:33 -08:00
parent d85ae41fba
commit 58958bbcce
105 changed files with 2514 additions and 934 deletions

View File

@@ -340,9 +340,9 @@ class CliqueTensorStepKernel(torch.autograd.Function):
grad_out_a,
grad_out_j,
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
grad_out_p.shape[0],
grad_out_p.shape[1],
grad_out_p.shape[2],
)
return (
u_grad,
@@ -412,9 +412,9 @@ class CliqueTensorStepIdxKernel(torch.autograd.Function):
grad_out_a,
grad_out_j,
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
grad_out_p.shape[0],
grad_out_p.shape[1],
grad_out_p.shape[2],
)
return (
u_grad,
@@ -483,9 +483,9 @@ class CliqueTensorStepCentralDifferenceKernel(torch.autograd.Function):
grad_out_a.contiguous(),
grad_out_j.contiguous(),
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
grad_out_p.shape[0],
grad_out_p.shape[1],
grad_out_p.shape[2],
0,
)
return (
@@ -557,9 +557,9 @@ class CliqueTensorStepIdxCentralDifferenceKernel(torch.autograd.Function):
grad_out_a.contiguous(),
grad_out_j.contiguous(),
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
grad_out_p.shape[0],
grad_out_p.shape[1],
grad_out_p.shape[2],
0,
)
return (
@@ -626,9 +626,9 @@ class CliqueTensorStepCoalesceKernel(torch.autograd.Function):
grad_out_a.transpose(-1, -2).contiguous(),
grad_out_j.transpose(-1, -2).contiguous(),
traj_dt,
out_grad_position.shape[0],
out_grad_position.shape[1],
out_grad_position.shape[2],
grad_out_p.shape[0],
grad_out_p.shape[1],
grad_out_p.shape[2],
)
return (
u_grad.transpose(-1, -2).contiguous(),
@@ -890,16 +890,16 @@ def interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType):
return mat
def action_interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType):
def action_interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType, offset: int = 4):
mat = torch.zeros(
((h - 1) * (int_steps), h), device=tensor_args.device, dtype=tensor_args.dtype
)
delta = torch.arange(0, int_steps - 2, device=tensor_args.device, dtype=tensor_args.dtype) / (
int_steps - 1.0 - 2
)
delta = torch.arange(
0, int_steps - offset + 1, device=tensor_args.device, dtype=tensor_args.dtype
) / (int_steps - offset)
for i in range(h - 1):
mat[i * int_steps : i * (int_steps) + int_steps - 1 - 2, i] = delta.flip(0)[1:]
mat[i * int_steps : i * (int_steps) + int_steps - 1 - 2, i + 1] = delta[1:]
mat[-3:, 1] = 1.0
mat[i * int_steps : i * (int_steps) + int_steps - offset, i] = delta.flip(0)[1:]
mat[i * int_steps : i * (int_steps) + int_steps - offset, i + 1] = delta[1:]
mat[-offset:, 1] = 1.0
return mat