update to 0.6.2

This commit is contained in:
Balakumar Sundaralingam
2023-12-15 02:01:33 -08:00
parent d85ae41fba
commit 58958bbcce
105 changed files with 2514 additions and 934 deletions

View File

@@ -28,8 +28,15 @@
#define Y_ROT 4
#define Z_ROT 5
#define X_PRISM_NEG 6
#define Y_PRISM_NEG 7
#define Z_PRISM_NEG 8
#define X_ROT_NEG 9
#define Y_ROT_NEG 10
#define Z_ROT_NEG 11
#define MAX_BATCH_PER_BLOCK 32 // tunable parameter for improving occupancy
#define MAX_BW_BATCH_PER_BLOCK 8 // tunable parameter for improving occupancy
#define MAX_BW_BATCH_PER_BLOCK 16 // tunable parameter for improving occupancy
#define MAX_TOTAL_LINKS \
750 // limited by shared memory size. We need to fit 16 * float32 per link
@@ -238,67 +245,46 @@ __device__ __forceinline__ void prism_fn(const scalar_t *fixedTransform,
}
}
template <typename scalar_t>
__device__ __forceinline__ void xprism_fn(const scalar_t *fixedTransform,
const float angle, const int col_idx,
float *JM) {
__device__ __forceinline__ void update_joint_type_direction(int &j_type, int8_t &axis_sign)
{
// Assume that input j_type >= 0 . Check fixed joint outside of this function.
switch (col_idx) {
case 0:
case 1:
case 2:
fixed_joint_fn(&fixedTransform[col_idx], &JM[0]);
break;
case 3:
JM[0] = fixedTransform[0] * angle + fixedTransform[3]; // FT_0[1];
JM[1] = fixedTransform[M] * angle + fixedTransform[M + 3]; // FT_1[1];
JM[2] =
fixedTransform[M + M] * angle + fixedTransform[M + M + 3]; // FT_2[1];
JM[3] = 1;
break;
// Don't do anything if j_type < 6. j_type range is [0, 11]
// divergence here.
axis_sign = 1;
if (j_type >= 6)
{
j_type -= 6;
axis_sign = -1;
}
}
template <typename scalar_t>
__device__ __forceinline__ void yprism_fn(const scalar_t *fixedTransform,
const float angle, const int col_idx,
float *JM) {
switch (col_idx) {
case 0:
case 1:
case 2:
fixed_joint_fn(&fixedTransform[col_idx], &JM[0]);
break;
case 3:
JM[0] = fixedTransform[1] * angle + fixedTransform[3]; // FT_0[1];
JM[1] = fixedTransform[M + 1] * angle + fixedTransform[M + 3]; // FT_1[1];
JM[2] = fixedTransform[M + M + 1] * angle +
fixedTransform[M + M + 3]; // FT_2[1];
JM[3] = 1;
break;
// Maps a negative-axis joint type code onto its positive-axis counterpart.
//
// j_type: in/out joint type code. The caller must already have excluded fixed
//         joints, so the value is expected to be >= 0 and within [0, 11].
//         Codes 6..11 (the *_NEG variants) are shifted down by 6 onto 0..5;
//         codes below 6 are returned unchanged.
//
// This overload only normalizes the type; it does not report the axis sign.
__device__ __forceinline__ void update_joint_type_direction(int &j_type)
{
  // The original authors flagged this condition as a point of thread divergence.
  j_type = (j_type >= 6) ? (j_type - 6) : j_type;
}
template <typename scalar_t>
__device__ __forceinline__ void zprism_fn(const scalar_t *fixedTransform,
const float angle, const int col_idx,
float *JM) {
switch (col_idx) {
case 0:
case 1:
case 2:
fixed_joint_fn(&fixedTransform[col_idx], &JM[0]);
break;
case 3:
JM[0] = fixedTransform[2] * angle + fixedTransform[3]; // FT_0[1];
JM[1] = fixedTransform[M + 2] * angle + fixedTransform[M + 3]; // FT_1[1];
JM[2] = fixedTransform[M + M + 2] * angle +
fixedTransform[M + M + 3]; // FT_2[1];
JM[3] = 1;
break;
}
// Applies a joint's axis direction to its joint value and normalizes its type.
//
// angle:  in/out joint value. It is multiplied by -1 when j_type names a
//         negative-axis joint (j_type > 5) and by +1 otherwise.
// j_type: in/out joint type code, expected in [0, 11]. Fixed joints must be
//         filtered out by the caller. On return it has been passed through
//         update_joint_type_direction, i.e. codes 6..11 become 0..5.
__device__ __forceinline__ void update_axis_direction(
  float &angle,
  int   &j_type)
{
  // +1 for positive-axis codes (<= 5), -1 for negative-axis codes (> 5).
  const float direction = (j_type >= 6) ? -1.0f : 1.0f;

  angle = direction * angle;
  update_joint_type_direction(j_type);
}
// In the following versions of rot_fn, some non-nan values may become nan as we
@@ -423,7 +409,7 @@ __device__ __forceinline__ void zrot_fn(const scalar_t *fixedTransform,
template <typename psum_t>
__device__ __forceinline__ void
rot_backward_translation(const float3 &vec, float *cumul_mat, float *l_pos,
const float3 &loc_grad, psum_t &grad_q) {
const float3 &loc_grad, psum_t &grad_q, const int8_t axis_sign=1) {
float3 e_pos, j_pos;
@@ -433,136 +419,141 @@ rot_backward_translation(const float3 &vec, float *cumul_mat, float *l_pos,
// compute position gradient:
j_pos = *(float3 *)&l_pos[0] - e_pos; // - e_pos;
scale_cross_sum(vec, j_pos, loc_grad, grad_q); // cross product
float3 scale_grad = axis_sign * loc_grad;
scale_cross_sum(vec, j_pos, scale_grad, grad_q); // cross product
}
// Accumulates into grad_q the signed dot product of an axis vector with a
// gradient vector.
//
// vec:       axis vector (callers in this file build it from one column of the
//            link's cumulative transform matrix).
// grad_vec:  gradient vector that is projected onto vec.
// grad_q:    in/out accumulator; axis_sign * dot(vec, grad_vec) is added to it.
// axis_sign: multiplier applied to the projection; defaults to 1. Callers pass
//            -1 for joints whose type was one of the *_NEG codes.
template <typename psum_t>
__device__ __forceinline__ void
rot_backward_rotation(const float3 vec, const float3 grad_vec, psum_t &grad_q, const int8_t axis_sign=1) {
  const auto projection = dot(vec, grad_vec);

  grad_q += axis_sign * projection;
}
template <typename psum_t>
__device__ __forceinline__ void
prism_backward_translation(const float3 vec, const float3 grad_vec,
psum_t &grad_q) {
grad_q += dot(vec, grad_vec);
psum_t &grad_q, const int8_t axis_sign=1) {
grad_q += axis_sign * dot(vec, grad_vec);
}
template <typename psum_t>
__device__ __forceinline__ void
rot_backward_rotation(const float3 vec, const float3 grad_vec, psum_t &grad_q) {
grad_q += dot(vec, grad_vec);
}
template <typename psum_t>
__device__ __forceinline__ void
z_rot_backward(float *link_cumul_mat, float *l_pos, float3 &loc_grad_position,
float3 &loc_grad_orientation, psum_t &grad_q) {
float3 &loc_grad_orientation, psum_t &grad_q, const int8_t axis_sign=1) {
float3 vec =
make_float3(link_cumul_mat[2], link_cumul_mat[6], link_cumul_mat[10]);
// get rotation vector:
rot_backward_translation(vec, &link_cumul_mat[0], &l_pos[0],
loc_grad_position, grad_q);
loc_grad_position, grad_q, axis_sign);
rot_backward_rotation(vec, loc_grad_orientation, grad_q);
rot_backward_rotation(vec, loc_grad_orientation, grad_q, axis_sign);
}
template <typename psum_t>
__device__ __forceinline__ void
x_rot_backward(float *link_cumul_mat, float *l_pos, float3 &loc_grad_position,
float3 &loc_grad_orientation, psum_t &grad_q) {
float3 &loc_grad_orientation, psum_t &grad_q, const int8_t axis_sign=1) {
float3 vec =
make_float3(link_cumul_mat[0], link_cumul_mat[4], link_cumul_mat[8]);
// get rotation vector:
rot_backward_translation(vec, &link_cumul_mat[0], &l_pos[0],
loc_grad_position, grad_q);
loc_grad_position, grad_q, axis_sign);
rot_backward_rotation(vec, loc_grad_orientation, grad_q);
rot_backward_rotation(vec, loc_grad_orientation, grad_q, axis_sign);
}
template <typename psum_t>
__device__ __forceinline__ void
y_rot_backward(float *link_cumul_mat, float *l_pos, float3 &loc_grad_position,
float3 &loc_grad_orientation, psum_t &grad_q) {
float3 &loc_grad_orientation, psum_t &grad_q, const int8_t axis_sign=1) {
float3 vec =
make_float3(link_cumul_mat[1], link_cumul_mat[5], link_cumul_mat[9]);
// get rotation vector:
rot_backward_translation(vec, &link_cumul_mat[0], &l_pos[0],
loc_grad_position, grad_q);
loc_grad_position, grad_q, axis_sign);
rot_backward_rotation(vec, loc_grad_orientation, grad_q);
rot_backward_rotation(vec, loc_grad_orientation, grad_q, axis_sign);
}
template <typename psum_t>
__device__ __forceinline__ void
xyz_prism_backward_translation(float *cumul_mat, float3 &loc_grad,
psum_t &grad_q, int xyz) {
psum_t &grad_q, int xyz, const int8_t axis_sign=1) {
prism_backward_translation(
make_float3(cumul_mat[0 + xyz], cumul_mat[4 + xyz], cumul_mat[8 + xyz]),
loc_grad, grad_q);
loc_grad, grad_q, axis_sign);
}
template <typename psum_t>
__device__ __forceinline__ void x_prism_backward_translation(float *cumul_mat,
float3 &loc_grad,
psum_t &grad_q) {
psum_t &grad_q,
const int8_t axis_sign=1) {
// get rotation vector:
prism_backward_translation(
make_float3(cumul_mat[0], cumul_mat[4], cumul_mat[8]), loc_grad, grad_q);
make_float3(cumul_mat[0], cumul_mat[4], cumul_mat[8]), loc_grad, grad_q, axis_sign);
}
template <typename psum_t>
__device__ __forceinline__ void y_prism_backward_translation(float *cumul_mat,
float3 &loc_grad,
psum_t &grad_q) {
psum_t &grad_q,
const int8_t axis_sign=1) {
// get rotation vector:
prism_backward_translation(
make_float3(cumul_mat[1], cumul_mat[5], cumul_mat[9]), loc_grad, grad_q);
make_float3(cumul_mat[1], cumul_mat[5], cumul_mat[9]), loc_grad, grad_q, axis_sign);
}
template <typename psum_t>
__device__ __forceinline__ void z_prism_backward_translation(float *cumul_mat,
float3 &loc_grad,
psum_t &grad_q) {
psum_t &grad_q,
const int8_t axis_sign=1) {
// get rotation vector:
prism_backward_translation(
make_float3(cumul_mat[2], cumul_mat[6], cumul_mat[10]), loc_grad, grad_q);
make_float3(cumul_mat[2], cumul_mat[6], cumul_mat[10]), loc_grad, grad_q, axis_sign);
}
__device__ __forceinline__ void
xyz_rot_backward_translation(float *cumul_mat, float *l_pos, float3 &loc_grad,
float &grad_q, int xyz) {
float &grad_q, int xyz, const int8_t axis_sign=1) {
// get rotation vector:
rot_backward_translation(
make_float3(cumul_mat[0 + xyz], cumul_mat[4 + xyz], cumul_mat[8 + xyz]),
&cumul_mat[0], &l_pos[0], loc_grad, grad_q);
&cumul_mat[0], &l_pos[0], loc_grad, grad_q, axis_sign);
}
template <typename psum_t>
__device__ __forceinline__ void
x_rot_backward_translation(float *cumul_mat, float *l_pos, float3 &loc_grad,
psum_t &grad_q) {
psum_t &grad_q, const int8_t axis_sign=1) {
// get rotation vector:
rot_backward_translation(
make_float3(cumul_mat[0], cumul_mat[4], cumul_mat[8]), &cumul_mat[0],
&l_pos[0], loc_grad, grad_q);
&l_pos[0], loc_grad, grad_q, axis_sign);
}
template <typename psum_t>
__device__ __forceinline__ void
y_rot_backward_translation(float *cumul_mat, float *l_pos, float3 &loc_grad,
psum_t &grad_q) {
psum_t &grad_q, const int8_t axis_sign=1) {
// get rotation vector:
rot_backward_translation(
make_float3(cumul_mat[1], cumul_mat[5], cumul_mat[9]), &cumul_mat[0],
&l_pos[0], loc_grad, grad_q);
&l_pos[0], loc_grad, grad_q, axis_sign);
}
template <typename psum_t>
__device__ __forceinline__ void
z_rot_backward_translation(float *cumul_mat, float *l_pos, float3 &loc_grad,
psum_t &grad_q) {
psum_t &grad_q, const int8_t axis_sign=1) {
// get rotation vector:
rot_backward_translation(
make_float3(cumul_mat[2], cumul_mat[6], cumul_mat[10]), &cumul_mat[0],
&l_pos[0], loc_grad, grad_q);
&l_pos[0], loc_grad, grad_q, axis_sign);
}
// An optimized version of kin_fused_warp_kernel.
@@ -604,7 +595,7 @@ kin_fused_warp_kernel2(scalar_t *link_pos, // batchSize xz store_n_links x M x M
*(float4 *)&global_cumul_mat[batch * nlinks * 16 + col_idx * M] =
*(float4 *)&cumul_mat[matAddrBase + col_idx * M];
}
for (int8_t l = 1; l < nlinks; l++) // TODO: add base link transform
for (int8_t l = 1; l < nlinks; l++) //
{
// get one row of fixedTransform
@@ -616,21 +607,27 @@ kin_fused_warp_kernel2(scalar_t *link_pos, // batchSize xz store_n_links x M x M
// check joint type and use one of the helper functions:
float JM[M];
int j_type = jointMapType[l];
float angle = q[batch * njoints + jointMap[l]];
if (j_type == FIXED) {
if (j_type == FIXED)
{
fixed_joint_fn(&fixedTransform[ftAddrStart + col_idx], &JM[0]);
} else if (j_type <= Z_PRISM) {
prism_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0], j_type);
} else if (j_type == X_ROT) {
xrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else if (j_type == Y_ROT) {
yrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else if (j_type == Z_ROT) {
zrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else {
assert(jointMapType[l] > -2 &&
jointMapType[l] < 6); // joint type not supported
}
}
else
{
float angle = q[batch * njoints + jointMap[l]];
update_axis_direction(angle, j_type);
if (j_type <= Z_PRISM) {
prism_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0], j_type);
} else if (j_type == X_ROT) {
xrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else if (j_type == Y_ROT) {
yrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else if (j_type == Z_ROT) {
zrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else {
assert(j_type >= FIXED && j_type <= Z_ROT);
}
}
#pragma unroll 4
for (int i = 0; i < M; i++) {
@@ -652,7 +649,8 @@ kin_fused_warp_kernel2(scalar_t *link_pos, // batchSize xz store_n_links x M x M
int16_t read_cumul_idx = -1;
int16_t spheres_perthread = (nspheres + 3) / 4;
for (int16_t i = 0; i < spheres_perthread; i++) {
const int16_t sph_idx = col_idx * spheres_perthread + i;
//const int16_t sph_idx = col_idx * spheres_perthread + i;
const int16_t sph_idx = col_idx + i * 4;
// const int8_t sph_idx =
// i * 4 + col_idx; // different order such that adjacent
// spheres are in neighboring threads
@@ -749,23 +747,30 @@ __global__ void kin_fused_backward_kernel3(
int inAddrStart = matAddrBase + linkMap[l] * M * M;
int outAddrStart = matAddrBase + l * M * M; // + (t % M) * M;
float angle = q[batch * njoints + jointMap[l]];
int j_type = jointMapType[l];
if (j_type == FIXED) {
fixed_joint_fn(&fixedTransform[ftAddrStart + col_idx], &JM[0]);
} else if (j_type == Z_ROT) {
zrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else if (j_type <= Z_PRISM) {
prism_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0], j_type);
} else if (j_type == X_ROT) {
xrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else if (j_type == Y_ROT) {
yrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else {
assert(jointMapType[l] > -2 &&
jointMapType[l] < 6); // joint type not supported
}
}
else {
float angle = q[batch * njoints + jointMap[l]];
update_axis_direction(angle, j_type);
if (j_type <= Z_PRISM) {
prism_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0], j_type);
} else if (j_type == X_ROT) {
xrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else if (j_type == Y_ROT) {
yrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else if (j_type == Z_ROT) {
zrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
} else {
assert(j_type >= FIXED && j_type <= Z_ROT);
}
}
// fetch one row of cumul_mat, multiply with a column, which is in JM
cumul_mat[outAddrStart + elem_idx] =
dot(*(float4 *)&cumul_mat[inAddrStart + ((elem_idx / 4) * M)],
@@ -789,7 +794,8 @@ __global__ void kin_fused_backward_kernel3(
const int spheres_perthread = (nspheres + 15) / 16;
for (int i = 0; i < spheres_perthread; i++) {
const int sph_idx = elem_idx * spheres_perthread + i;
//const int sph_idx = elem_idx * spheres_perthread + i;
const int sph_idx = elem_idx + i * 16;
if (sph_idx >= nspheres) {
break;
}
@@ -808,40 +814,49 @@ __global__ void kin_fused_backward_kernel3(
// read cumul idx:
read_cumul_idx = linkSphereMap[sph_idx];
float spheres_mem[4];
transform_sphere(&cumul_mat[matAddrBase + read_cumul_idx * 16],
&robotSpheres[sphAddrs], &spheres_mem[0]);
for (int j = read_cumul_idx; j > -1; j--) {
// assuming this sphere only depends on links lower than this index
// This could be relaxed by making read_cumul_idx = number of links.
//const int16_t loop_max = read_cumul_idx;
const int16_t loop_max = nlinks - 1;
for (int j = loop_max; j > -1; j--) {
if (linkChainMap[read_cumul_idx * nlinks + j] == 0.0) {
continue;
}
int8_t axis_sign = 1;
int j_type = jointMapType[j];
if(j_type != FIXED)
{
update_joint_type_direction(j_type, axis_sign);
}
if (j_type == Z_ROT) {
float result = 0.0;
z_rot_backward_translation(&cumul_mat[matAddrBase + j * 16],
&spheres_mem[0], loc_grad_sphere, result);
&spheres_mem[0], loc_grad_sphere, result, axis_sign);
psum_grad[jointMap[j]] += (psum_t)result;
} else if (j_type >= X_PRISM && j_type <= Z_PRISM) {
float result = 0.0;
xyz_prism_backward_translation(&cumul_mat[matAddrBase + j * 16],
loc_grad_sphere, result, j_type);
loc_grad_sphere, result, j_type, axis_sign);
psum_grad[jointMap[j]] += (psum_t)result;
} else if (j_type == X_ROT) {
float result = 0.0;
x_rot_backward_translation(&cumul_mat[matAddrBase + j * 16],
&spheres_mem[0], loc_grad_sphere, result);
&spheres_mem[0], loc_grad_sphere, result, axis_sign);
psum_grad[jointMap[j]] += (psum_t)result;
} else if (j_type == Y_ROT) {
float result = 0.0;
y_rot_backward_translation(&cumul_mat[matAddrBase + j * 16],
&spheres_mem[0], loc_grad_sphere, result);
&spheres_mem[0], loc_grad_sphere, result, axis_sign);
psum_grad[jointMap[j]] += (psum_t)result;
}
}
}
// Instead of accumulating the sphere_grad and link_grad separately, we will
// accumulate them together once below.
//
@@ -867,7 +882,7 @@ __global__ void kin_fused_backward_kernel3(
float3 g_position = *(float3 *)&grad_nlinks_pos[batchAddrs * 3 + i * 3];
float4 g_orientation_t =
*(float4 *)&grad_nlinks_quat[batchAddrs * 4 + i * 4];
// TODO: sparisty check here:
// sparsity check here:
if (enable_sparsity_opt) {
if (g_position.x == 0 && g_position.y == 0 && g_position.z == 0 &&
g_orientation_t.y == 0 && g_orientation_t.z == 0 &&
@@ -879,6 +894,7 @@ __global__ void kin_fused_backward_kernel3(
make_float3(g_orientation_t.y, g_orientation_t.z, g_orientation_t.w);
const int16_t l_map = storeLinkMap[i];
float l_pos[3];
const int outAddrStart = matAddrBase + l_map * M * M;
@@ -886,34 +902,47 @@ __global__ void kin_fused_backward_kernel3(
l_pos[1] = cumul_mat[outAddrStart + M + 3];
l_pos[2] = cumul_mat[outAddrStart + M * 2 + 3];
int16_t joints_per_thread = (l_map + 15) / 16;
for (int16_t k = joints_per_thread; k >= 0; k--) {
int16_t j = k * M + elem_idx;
if (j > l_map || j < 0)
const int16_t max_lmap = nlinks - 1;
const int16_t joints_per_thread = (max_lmap + 15) / 16;
//for (int16_t k = joints_per_thread; k >= 0; k--)
for (int16_t k=0; k < joints_per_thread; k++)
{
//int16_t j = elem_idx * joints_per_thread + k;
int16_t j = elem_idx + k * 16;
//int16_t j = k * M + elem_idx;
if (j > max_lmap || j < 0)
continue;
// This can be spread across threads as they are not sequential?
if (linkChainMap[l_map * nlinks + j] == 0.0) {
continue;
}
int16_t j_idx = jointMap[j];
int8_t j_type = jointMapType[j];
int j_type = jointMapType[j];
int8_t axis_sign = 1;
if (j_type != FIXED)
{
update_joint_type_direction(j_type, axis_sign);
}
// get rotation vector:
if (j_type == Z_ROT) {
z_rot_backward(&cumul_mat[matAddrBase + (j)*M * M], &l_pos[0],
g_position, g_orientation, psum_grad[j_idx]);
g_position, g_orientation, psum_grad[j_idx], axis_sign);
} else if (j_type >= X_PRISM & j_type <= Z_PRISM) {
xyz_prism_backward_translation(&cumul_mat[matAddrBase + j * 16],
g_position, psum_grad[j_idx], j_type);
g_position, psum_grad[j_idx], j_type, axis_sign);
} else if (j_type == X_ROT) {
x_rot_backward(&cumul_mat[matAddrBase + (j)*M * M], &l_pos[0],
g_position, g_orientation, psum_grad[j_idx]);
g_position, g_orientation, psum_grad[j_idx], axis_sign);
} else if (j_type == Y_ROT) {
y_rot_backward(&cumul_mat[matAddrBase + (j)*M * M], &l_pos[0],
g_position, g_orientation, psum_grad[j_idx]);
g_position, g_orientation, psum_grad[j_idx], axis_sign);
}
}
}
if (PARALLEL_WRITE) {
// accumulate the partial sums across the 16 threads
@@ -931,7 +960,8 @@ __global__ void kin_fused_backward_kernel3(
#pragma unroll
for (int16_t j = 0; j < joints_per_thread; j++) {
const int16_t j_idx = elem_idx * joints_per_thread + j;
//const int16_t j_idx = elem_idx * joints_per_thread + j;
const int16_t j_idx = elem_idx + j * 16;
if (j_idx >= njoints) {
break;
}

View File

@@ -441,7 +441,7 @@ __global__ void reduce_kernel(
rho_buffer[threadIdx.x * batchsize + batch] = rho;
}
}
template <typename scalar_t, typename psum_t>
template <typename scalar_t, typename psum_t, bool rolled_ys>
__global__ void lbfgs_update_buffer_and_step_v1(
scalar_t *step_vec, // b x 175
scalar_t *rho_buffer, // m x b x 1
@@ -452,7 +452,6 @@ __global__ void lbfgs_update_buffer_and_step_v1(
scalar_t *grad_0, // b x 175
const scalar_t *grad_q, // b x 175
const float epsilon, const int batchsize, const int m, const int v_dim,
const bool rolled_ys = false,
const bool stable_mode =
false) // s_buffer and y_buffer are not rolled by default
{
@@ -485,6 +484,7 @@ __global__ void lbfgs_update_buffer_and_step_v1(
scalar_t s =
q[batch * v_dim + threadIdx.x] - x_0[batch * v_dim + threadIdx.x];
reduce_v1(y * s, v_dim, &data[0], &result);
//reduce(y * s, v_dim, &data[0], &result);
scalar_t numerator = result;
// scalar_t rho = 1.0/numerator;
@@ -824,14 +824,14 @@ lbfgs_cuda_fuse(torch::Tensor step_vec, torch::Tensor rho_buffer,
y_buffer.scalar_type(), "lbfgs_cuda_fuse_kernel_v1", [&] {
smemsize = 3 * m * threadsPerBlock * sizeof(scalar_t) +
m * batch_size * sizeof(scalar_t);
lbfgs_update_buffer_and_step_v1<scalar_t, scalar_t>
lbfgs_update_buffer_and_step_v1<scalar_t, scalar_t, false>
<<<blocksPerGrid, threadsPerBlock, smemsize, stream>>>(
step_vec.data_ptr<scalar_t>(),
rho_buffer.data_ptr<scalar_t>(),
y_buffer.data_ptr<scalar_t>(), s_buffer.data_ptr<scalar_t>(),
q.data_ptr<scalar_t>(), x_0.data_ptr<scalar_t>(),
grad_0.data_ptr<scalar_t>(), grad_q.data_ptr<scalar_t>(),
epsilon, batch_size, m, v_dim, false, stable_mode);
epsilon, batch_size, m, v_dim, stable_mode);
});
}

View File

@@ -164,10 +164,14 @@ compute_distance(float *distance_vec, float &distance, float &position_distance,
distance = 0;
if (rotation_distance > vec_convergence[0] * vec_convergence[0]) {
rotation_distance = sqrtf(rotation_distance);
//rotation_distance -= vec_convergence[0];
distance += rotation_weight * rotation_distance;
}
if (position_distance > vec_convergence[1] * vec_convergence[1]) {
position_distance = sqrtf(position_distance);
//position_distance -= vec_convergence[1];
distance += position_weight * position_distance;
}
}
@@ -202,9 +206,12 @@ __device__ __forceinline__ void compute_metric_distance(
distance = 0;
if (rotation_distance > vec_convergence[0] * vec_convergence[0]) {
rotation_distance = sqrtf(rotation_distance);
//rotation_distance -= vec_convergence[0];
distance += rotation_weight * log2f(coshf(r_alpha * rotation_distance));
}
if (position_distance > vec_convergence[1] * vec_convergence[1]) {
//position_distance -= vec_convergence[1];
position_distance = sqrtf(position_distance);
distance += position_weight * log2f(coshf(p_alpha * position_distance));
}
@@ -372,6 +379,14 @@ __global__ void goalset_pose_distance_kernel(
// write out pose distance:
out_distance[batch_idx * horizon + h_idx] = best_distance;
if (write_distance) {
if(position_weight == 0.0)
{
best_position_distance = 0.0;
}
if (rotation_weight == 0.0)
{
best_rotation_distance = 0.0;
}
out_position_distance[batch_idx * horizon + h_idx] = best_position_distance;
out_rotation_distance[batch_idx * horizon + h_idx] = best_rotation_distance;
}
@@ -522,6 +537,14 @@ __global__ void goalset_pose_distance_metric_kernel(
// write out pose distance:
out_distance[batch_idx * horizon + h_idx] = best_distance;
if (write_distance) {
if(position_weight == 0.0)
{
best_position_distance = 0.0;
}
if (rotation_weight == 0.0)
{
best_rotation_distance = 0.0;
}
out_position_distance[batch_idx * horizon + h_idx] = best_position_distance;
out_rotation_distance[batch_idx * horizon + h_idx] = best_rotation_distance;
}

View File

@@ -79,7 +79,7 @@ std::vector<torch::Tensor> step_position_clique_wrapper(
const torch::Tensor traj_dt, const int batch_size, const int horizon,
const int dof) {
const at::cuda::OptionalCUDAGuard guard(u_position.device());
assert(false); // not supported
CHECK_INPUT(u_position);
CHECK_INPUT(out_position);
CHECK_INPUT(out_velocity);
@@ -155,7 +155,7 @@ std::vector<torch::Tensor> backward_step_position_clique_wrapper(
const torch::Tensor grad_jerk, const torch::Tensor traj_dt,
const int batch_size, const int horizon, const int dof) {
const at::cuda::OptionalCUDAGuard guard(grad_position.device());
assert(false); // not supported
CHECK_INPUT(out_grad_position);
CHECK_INPUT(grad_position);
CHECK_INPUT(grad_velocity);

View File

@@ -319,14 +319,16 @@ __device__ __forceinline__ void compute_central_difference(scalar_t *out_positio
const float dt = traj_dt[0]; // assume same dt across traj TODO: Implement variable dt
// dt here is actually 1/dt;
const float dt_inv = 1.0 / dt;
const float st_jerk = 0.0; // Note: start jerk can also be passed from global memory
// read start state:
float out_pos=0.0, out_vel=0.0, out_acc=0.0, out_jerk=0.0;
float st_pos=0.0, st_vel=0.0, st_acc = 0.0;
const int b_addrs = b_idx * horizon * dof;
const int b_addrs_action = b_idx * (horizon-4) * dof;
float in_pos[5]; // create a 5 value scalar
const float acc_scale = 1.0;
#pragma unroll 5
for (int i=0; i< 5; i ++){
in_pos[i] = 0.0;
@@ -337,92 +339,108 @@ __device__ __forceinline__ void compute_central_difference(scalar_t *out_positio
st_acc = start_acceleration[b_offset * dof + d_idx];
}
if (h_idx > 3 && h_idx < horizon - 4)
{
in_pos[0] = u_position[b_addrs + (h_idx - 4) * dof + d_idx];
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
in_pos[2] = u_position[b_addrs + (h_idx - 2 ) * dof + d_idx];
in_pos[3] = u_position[b_addrs + (h_idx - 1) * dof + d_idx];
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
if (h_idx > 3 && h_idx < horizon - 4){
in_pos[0] = u_position[b_addrs_action + (h_idx - 4) * dof + d_idx];
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
in_pos[2] = u_position[b_addrs_action + (h_idx - 2 ) * dof + d_idx];
in_pos[3] = u_position[b_addrs_action + (h_idx - 1) * dof + d_idx];
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
}
else if (h_idx == 0)
{
in_pos[0] = st_pos - 3 * dt * ( st_vel + (0.5 * st_acc * dt)); // start -1, start, u0, u1
in_pos[1] = st_pos - 2 * dt * ( st_vel + (0.5 * st_acc * dt));
in_pos[2] = st_pos - dt * ( st_vel + (0.5 * st_acc * dt));
in_pos[0] = (3.0f/2) * ( - 1 * st_acc * (dt_inv * dt_inv) - (dt_inv * dt_inv * dt_inv) * st_jerk ) - 3.0f * dt_inv * st_vel + st_pos;
in_pos[1] = -2.0f * st_acc * dt_inv * dt_inv - (4.0f/3) * dt_inv * dt_inv * dt_inv * st_jerk - 2.0 * dt_inv * st_vel + st_pos;
in_pos[2] = -(3.0f/2) * st_acc * dt_inv * dt_inv - (7.0f/6) * dt_inv * dt_inv * dt_inv * st_jerk - dt_inv * st_vel + st_pos;
in_pos[3] = st_pos;
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
}
else if (h_idx == 1)
{
in_pos[0] = st_pos - 2 * dt * ( st_vel + (0.5 * st_acc * dt)); // start -1, start, u0, u1
in_pos[1] = st_pos - dt * ( st_vel + (0.5 * st_acc * dt));
in_pos[0] = -2.0f * st_acc * dt_inv * dt_inv - (4.0f/3) * dt_inv * dt_inv * dt_inv * st_jerk - 2.0 * dt_inv * st_vel + st_pos;
in_pos[1] = -(3.0f/2) * st_acc * dt_inv * dt_inv - (7.0f/6) * dt_inv * dt_inv * dt_inv * st_jerk - dt_inv * st_vel + st_pos;
in_pos[2] = st_pos;
in_pos[3] = u_position[b_addrs + (h_idx - 1) * dof + d_idx];
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
in_pos[3] = u_position[b_addrs_action + (h_idx - 1) * dof + d_idx];
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
}
else if (h_idx == 2)
{
in_pos[0] = st_pos - dt * ( st_vel + (0.5 * st_acc * dt)); // start -1, start, u0, u1
in_pos[0] = -(3.0f/2) * st_acc * dt_inv * dt_inv - (7.0f/6) * dt_inv * dt_inv * dt_inv * st_jerk - dt_inv * st_vel + st_pos;
in_pos[1] = st_pos;
in_pos[2] = u_position[b_addrs + (h_idx - 2) * dof + d_idx];
in_pos[3] = u_position[b_addrs + (h_idx - 1) * dof + d_idx];
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
in_pos[2] = u_position[b_addrs_action + (h_idx - 2) * dof + d_idx];
in_pos[3] = u_position[b_addrs_action + (h_idx - 1) * dof + d_idx];
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
}
else if (h_idx == 3)
{
in_pos[0] = st_pos;
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
in_pos[2] = u_position[b_addrs + (h_idx - 2 ) * dof + d_idx];
in_pos[3] = u_position[b_addrs + (h_idx - 1 ) * dof + d_idx];
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
in_pos[2] = u_position[b_addrs_action + (h_idx - 2 ) * dof + d_idx];
in_pos[3] = u_position[b_addrs_action + (h_idx - 1 ) * dof + d_idx];
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
}
else if (h_idx == horizon - 4)
{
in_pos[0] = u_position[b_addrs + (h_idx -4) * dof + d_idx];
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
in_pos[2] = u_position[b_addrs + (h_idx - 2 ) * dof + d_idx];
in_pos[3] = u_position[b_addrs + (h_idx - 1) * dof + d_idx];
in_pos[4] = in_pos[3];//in_pos[3]; //u_position[b_addrs + (h_idx - 1 + 2) * dof + d_idx];
in_pos[0] = u_position[b_addrs_action + (h_idx -4) * dof + d_idx];
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
in_pos[2] = u_position[b_addrs_action + (h_idx - 2 ) * dof + d_idx];
in_pos[3] = u_position[b_addrs_action + (h_idx - 1) * dof + d_idx];
in_pos[4] = in_pos[3];//in_pos[3]; //u_position[b_addrs_action + (h_idx - 1 + 2) * dof + d_idx];
}
else if (h_idx == horizon - 3)
{
in_pos[0] = u_position[b_addrs + (h_idx - 4) * dof + d_idx];
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
in_pos[2] = u_position[b_addrs + (h_idx - 2) * dof + d_idx];
in_pos[3] = in_pos[2];//u_position[b_addrs + (h_idx - 1 + 1) * dof + d_idx];
in_pos[4] = in_pos[2];//in_pos[3]; //u_position[b_addrs + (h_idx - 1 + 2) * dof + d_idx];
in_pos[0] = u_position[b_addrs_action + (h_idx - 4) * dof + d_idx];
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
in_pos[2] = u_position[b_addrs_action + (h_idx - 2) * dof + d_idx];
in_pos[3] = in_pos[2];//u_position[b_addrs_action + (h_idx - 1 + 1) * dof + d_idx];
in_pos[4] = in_pos[2];//in_pos[3]; //u_position[b_addrs_action + (h_idx - 1 + 2) * dof + d_idx];
}
else if (h_idx == horizon - 2)
{
in_pos[0] = u_position[b_addrs + (h_idx - 4) * dof + d_idx];
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
in_pos[0] = u_position[b_addrs_action + (h_idx - 4) * dof + d_idx];
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
in_pos[2] = in_pos[1];
in_pos[3] = in_pos[1];//u_position[b_addrs + (h_idx - 1 + 1) * dof + d_idx];
in_pos[4] = in_pos[1];//in_pos[3]; //u_position[b_addrs + (h_idx - 1 + 2) * dof + d_idx];
in_pos[3] = in_pos[1];//u_position[b_addrs_action + (h_idx - 1 + 1) * dof + d_idx];
in_pos[4] = in_pos[1];//in_pos[3]; //u_position[b_addrs_action + (h_idx - 1 + 2) * dof + d_idx];
}
else if (h_idx == horizon -1)
else if (h_idx == horizon - 1)
{
in_pos[0] = u_position[b_addrs + (h_idx - 4) * dof + d_idx];
in_pos[0] = u_position[b_addrs_action + (h_idx - 4) * dof + d_idx];
in_pos[1] = in_pos[0];
in_pos[2] = in_pos[0];//u_position[b_addrs + (h_idx - 1 ) * dof + d_idx];
in_pos[3] = in_pos[0];//u_position[b_addrs + (h_idx - 1 + 1) * dof + d_idx];
in_pos[4] = in_pos[0];//in_pos[3]; //u_position[b_addrs + (h_idx - 1 + 2) * dof + d_idx];
in_pos[2] = in_pos[0];//u_position[b_addrs_action + (h_idx - 1 ) * dof + d_idx];
in_pos[3] = in_pos[0];//u_position[b_addrs_action + (h_idx - 1 + 1) * dof + d_idx];
in_pos[4] = in_pos[0];//in_pos[3]; //u_position[b_addrs_action + (h_idx - 1 + 2) * dof + d_idx];
}
out_pos = in_pos[2];
// out_vel = (0.5 * in_pos[3] - 0.5 * in_pos[1]) * dt;
out_vel = ((0.083333333f) * in_pos[0] - (0.666666667f) * in_pos[1] + (0.666666667f) * in_pos[3] + (-0.083333333f) * in_pos[4]) * dt;
@@ -693,8 +711,9 @@ __global__ void backward_position_clique_loop_central_difference_kernel2(
return;
}
const int b_addrs = b_idx * horizon * dof;
const int b_addrs_action = b_idx * (horizon-4) * dof;
if (h_idx < 2 || h_idx > horizon - 2)
if (h_idx < 2 || h_idx >= horizon - 2)
{
return;
}
@@ -717,7 +736,7 @@ __global__ void backward_position_clique_loop_central_difference_kernel2(
g_jerk[i] = 0.0;
}
int hid = h_idx;
const int hid = h_idx;
g_pos[0] = grad_position[b_addrs + (hid)*dof + d_idx];
g_pos[1] = 0.0;
@@ -745,30 +764,35 @@ __global__ void backward_position_clique_loop_central_difference_kernel2(
(0.5f * g_jerk[0] - g_jerk[1] + g_jerk[3] - 0.5f * g_jerk[4]) * dt * dt * dt);
}
else if (hid == horizon -3)
else if (hid == horizon - 3)
{
// The below can cause oscillatory gradient steps.
/*
#pragma unroll
for (int i=0; i< 4; i++)
for (int i=0; i< 5; i++)
{
g_vel[i] = grad_velocity[b_addrs + ((hid - 2) + i)*dof + d_idx];
g_acc[i] = grad_acceleration[b_addrs + ((hid -2) + i)*dof + d_idx];
g_jerk[i] = grad_jerk[b_addrs + ((hid -2) + i)*dof + d_idx];
}
*/
g_pos[1] = grad_position[b_addrs + (hid + 1)*dof + d_idx];
g_pos[2] = grad_position[b_addrs + (hid + 2)*dof + d_idx];
out_grad = (g_pos[0] + g_pos[1] + g_pos[2] +
out_grad = (g_pos[0] + g_pos[1] + g_pos[2]);
/* +
//((0.5) * g_vel[1] + (0.5) * g_vel[2]) * dt +
((-0.083333333f) * g_vel[0] + (0.583333333f) * g_vel[1] + (0.583333333f) * g_vel[2] + (-0.083333333f) * g_vel[3]) * dt +
((-0.083333333f) * g_acc[0] + (1.25f) * g_acc[1] + (-1.25f) * g_acc[2] + (0.083333333f) * g_acc[3]) * dt * dt +
//( g_acc[1] - g_acc[2]) * dt * dt +
(0.5f * g_jerk[0] - 0.5f * g_jerk[1] -0.5f * g_jerk[2] + 0.5f * g_jerk[3]) * dt * dt * dt);
*/
}
// write out:
out_grad_position[b_addrs + (h_idx-2)*dof + d_idx] = out_grad;
out_grad_position[b_addrs_action + (h_idx-2)*dof + d_idx] = out_grad;
}
// for MPPI:
@@ -1389,6 +1413,8 @@ std::vector<torch::Tensor> step_position_clique2(
}
else if (mode == CENTRAL_DIFF)
{
assert(u_position.sizes()[1] == horizon - 4);
AT_DISPATCH_FLOATING_TYPES(
out_position.scalar_type(), "step_position_clique", ([&] {
position_clique_loop_kernel2<scalar_t, CENTRAL_DIFF>
@@ -1436,7 +1462,7 @@ std::vector<torch::Tensor> step_position_clique2_idx(
if (mode == BWD_DIFF)
{
assert(false);
AT_DISPATCH_FLOATING_TYPES(
out_position.scalar_type(), "step_position_clique", ([&] {
position_clique_loop_idx_kernel2<scalar_t, BWD_DIFF>
@@ -1455,6 +1481,8 @@ std::vector<torch::Tensor> step_position_clique2_idx(
else if (mode == CENTRAL_DIFF)
{
assert(u_position.sizes()[1] == horizon - 4);
AT_DISPATCH_FLOATING_TYPES(
out_position.scalar_type(), "step_position_clique", ([&] {
position_clique_loop_idx_kernel2<scalar_t, CENTRAL_DIFF>
@@ -1538,7 +1566,7 @@ std::vector<torch::Tensor> backward_step_position_clique2(
if (mode == BWD_DIFF)
{
assert(false); // not supported anymore
AT_DISPATCH_FLOATING_TYPES(
out_grad_position.scalar_type(), "backward_step_position_clique", ([&] {
backward_position_clique_loop_backward_difference_kernel2<scalar_t>
@@ -1554,7 +1582,7 @@ std::vector<torch::Tensor> backward_step_position_clique2(
}
else if (mode == CENTRAL_DIFF)
{
assert(out_grad_position.sizes()[1] == horizon - 4);
AT_DISPATCH_FLOATING_TYPES(
out_grad_position.scalar_type(), "backward_step_position_clique", ([&] {
backward_position_clique_loop_central_difference_kernel2<scalar_t>