update to 0.6.2
This commit is contained in:
@@ -28,8 +28,15 @@
|
||||
#define Y_ROT 4
|
||||
#define Z_ROT 5
|
||||
|
||||
#define X_PRISM_NEG 6
|
||||
#define Y_PRISM_NEG 7
|
||||
#define Z_PRISM_NEG 8
|
||||
#define X_ROT_NEG 9
|
||||
#define Y_ROT_NEG 10
|
||||
#define Z_ROT_NEG 11
|
||||
|
||||
#define MAX_BATCH_PER_BLOCK 32 // tunable parameter for improving occupancy
|
||||
#define MAX_BW_BATCH_PER_BLOCK 8 // tunable parameter for improving occupancy
|
||||
#define MAX_BW_BATCH_PER_BLOCK 16 // tunable parameter for improving occupancy
|
||||
|
||||
#define MAX_TOTAL_LINKS \
|
||||
750 // limited by shared memory size. We need to fit 16 * float32 per link
|
||||
@@ -238,67 +245,46 @@ __device__ __forceinline__ void prism_fn(const scalar_t *fixedTransform,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ __forceinline__ void xprism_fn(const scalar_t *fixedTransform,
|
||||
const float angle, const int col_idx,
|
||||
float *JM) {
|
||||
__device__ __forceinline__ void update_joint_type_direction(int &j_type, int8_t &axis_sign)
|
||||
{
|
||||
// Assume that input j_type >= 0 . Check fixed joint outside of this function.
|
||||
|
||||
switch (col_idx) {
|
||||
case 0:
|
||||
case 1:
|
||||
case 2:
|
||||
fixed_joint_fn(&fixedTransform[col_idx], &JM[0]);
|
||||
break;
|
||||
case 3:
|
||||
JM[0] = fixedTransform[0] * angle + fixedTransform[3]; // FT_0[1];
|
||||
JM[1] = fixedTransform[M] * angle + fixedTransform[M + 3]; // FT_1[1];
|
||||
JM[2] =
|
||||
fixedTransform[M + M] * angle + fixedTransform[M + M + 3]; // FT_2[1];
|
||||
JM[3] = 1;
|
||||
break;
|
||||
// Don't do anything if j_type < 6. j_type range is [0, 11]
|
||||
// divergence here.
|
||||
axis_sign = 1;
|
||||
if (j_type >= 6)
|
||||
{
|
||||
j_type -= 6;
|
||||
axis_sign = -1;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ __forceinline__ void yprism_fn(const scalar_t *fixedTransform,
|
||||
const float angle, const int col_idx,
|
||||
float *JM) {
|
||||
|
||||
switch (col_idx) {
|
||||
case 0:
|
||||
case 1:
|
||||
case 2:
|
||||
fixed_joint_fn(&fixedTransform[col_idx], &JM[0]);
|
||||
break;
|
||||
case 3:
|
||||
JM[0] = fixedTransform[1] * angle + fixedTransform[3]; // FT_0[1];
|
||||
JM[1] = fixedTransform[M + 1] * angle + fixedTransform[M + 3]; // FT_1[1];
|
||||
JM[2] = fixedTransform[M + M + 1] * angle +
|
||||
fixedTransform[M + M + 3]; // FT_2[1];
|
||||
JM[3] = 1;
|
||||
break;
|
||||
__device__ __forceinline__ void update_joint_type_direction(int &j_type)
|
||||
{
|
||||
// Assume that input j_type >= 0 . Check fixed joint outside of this function.
|
||||
|
||||
// Don't do anything if j_type < 6. j_type range is [0, 11]
|
||||
// divergence here.
|
||||
if (j_type >= 6)
|
||||
{
|
||||
j_type -= 6;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ __forceinline__ void zprism_fn(const scalar_t *fixedTransform,
|
||||
const float angle, const int col_idx,
|
||||
float *JM) {
|
||||
|
||||
switch (col_idx) {
|
||||
case 0:
|
||||
case 1:
|
||||
case 2:
|
||||
fixed_joint_fn(&fixedTransform[col_idx], &JM[0]);
|
||||
break;
|
||||
case 3:
|
||||
JM[0] = fixedTransform[2] * angle + fixedTransform[3]; // FT_0[1];
|
||||
JM[1] = fixedTransform[M + 2] * angle + fixedTransform[M + 3]; // FT_1[1];
|
||||
JM[2] = fixedTransform[M + M + 2] * angle +
|
||||
fixedTransform[M + M + 3]; // FT_2[1];
|
||||
JM[3] = 1;
|
||||
break;
|
||||
}
|
||||
__device__ __forceinline__ void update_axis_direction(
|
||||
float &angle,
|
||||
int &j_type)
|
||||
{
|
||||
// Assume that input j_type >= 0 . Check fixed joint outside of this function.
|
||||
// sign should be +ve <= 5 and -ve >5
|
||||
// j_type range is [0, 11].
|
||||
// cuda code treats -1.0 * 0.0 as negative. Hence we subtract 6. If in future, -1.0 * 0.0 = +ve,
|
||||
// then this code should be j_type - 5.
|
||||
angle = -1 * copysignf(1.0, j_type - 6) * angle;
|
||||
|
||||
update_joint_type_direction(j_type);
|
||||
|
||||
}
|
||||
|
||||
// In the following versions of rot_fn, some non-nan values may become nan as we
|
||||
@@ -423,7 +409,7 @@ __device__ __forceinline__ void zrot_fn(const scalar_t *fixedTransform,
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
rot_backward_translation(const float3 &vec, float *cumul_mat, float *l_pos,
|
||||
const float3 &loc_grad, psum_t &grad_q) {
|
||||
const float3 &loc_grad, psum_t &grad_q, const int8_t axis_sign=1) {
|
||||
|
||||
float3 e_pos, j_pos;
|
||||
|
||||
@@ -433,136 +419,141 @@ rot_backward_translation(const float3 &vec, float *cumul_mat, float *l_pos,
|
||||
|
||||
// compute position gradient:
|
||||
j_pos = *(float3 *)&l_pos[0] - e_pos; // - e_pos;
|
||||
scale_cross_sum(vec, j_pos, loc_grad, grad_q); // cross product
|
||||
float3 scale_grad = axis_sign * loc_grad;
|
||||
scale_cross_sum(vec, j_pos, scale_grad, grad_q); // cross product
|
||||
}
|
||||
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
rot_backward_rotation(const float3 vec, const float3 grad_vec, psum_t &grad_q, const int8_t axis_sign=1) {
|
||||
grad_q += axis_sign * dot(vec, grad_vec);
|
||||
}
|
||||
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
prism_backward_translation(const float3 vec, const float3 grad_vec,
|
||||
psum_t &grad_q) {
|
||||
grad_q += dot(vec, grad_vec);
|
||||
psum_t &grad_q, const int8_t axis_sign=1) {
|
||||
grad_q += axis_sign * dot(vec, grad_vec);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
rot_backward_rotation(const float3 vec, const float3 grad_vec, psum_t &grad_q) {
|
||||
grad_q += dot(vec, grad_vec);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
z_rot_backward(float *link_cumul_mat, float *l_pos, float3 &loc_grad_position,
|
||||
float3 &loc_grad_orientation, psum_t &grad_q) {
|
||||
float3 &loc_grad_orientation, psum_t &grad_q, const int8_t axis_sign=1) {
|
||||
float3 vec =
|
||||
make_float3(link_cumul_mat[2], link_cumul_mat[6], link_cumul_mat[10]);
|
||||
|
||||
// get rotation vector:
|
||||
rot_backward_translation(vec, &link_cumul_mat[0], &l_pos[0],
|
||||
loc_grad_position, grad_q);
|
||||
loc_grad_position, grad_q, axis_sign);
|
||||
|
||||
rot_backward_rotation(vec, loc_grad_orientation, grad_q);
|
||||
rot_backward_rotation(vec, loc_grad_orientation, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
x_rot_backward(float *link_cumul_mat, float *l_pos, float3 &loc_grad_position,
|
||||
float3 &loc_grad_orientation, psum_t &grad_q) {
|
||||
float3 &loc_grad_orientation, psum_t &grad_q, const int8_t axis_sign=1) {
|
||||
float3 vec =
|
||||
make_float3(link_cumul_mat[0], link_cumul_mat[4], link_cumul_mat[8]);
|
||||
|
||||
// get rotation vector:
|
||||
rot_backward_translation(vec, &link_cumul_mat[0], &l_pos[0],
|
||||
loc_grad_position, grad_q);
|
||||
loc_grad_position, grad_q, axis_sign);
|
||||
|
||||
rot_backward_rotation(vec, loc_grad_orientation, grad_q);
|
||||
rot_backward_rotation(vec, loc_grad_orientation, grad_q, axis_sign);
|
||||
}
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
y_rot_backward(float *link_cumul_mat, float *l_pos, float3 &loc_grad_position,
|
||||
float3 &loc_grad_orientation, psum_t &grad_q) {
|
||||
float3 &loc_grad_orientation, psum_t &grad_q, const int8_t axis_sign=1) {
|
||||
float3 vec =
|
||||
make_float3(link_cumul_mat[1], link_cumul_mat[5], link_cumul_mat[9]);
|
||||
|
||||
// get rotation vector:
|
||||
rot_backward_translation(vec, &link_cumul_mat[0], &l_pos[0],
|
||||
loc_grad_position, grad_q);
|
||||
loc_grad_position, grad_q, axis_sign);
|
||||
|
||||
rot_backward_rotation(vec, loc_grad_orientation, grad_q);
|
||||
rot_backward_rotation(vec, loc_grad_orientation, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
xyz_prism_backward_translation(float *cumul_mat, float3 &loc_grad,
|
||||
psum_t &grad_q, int xyz) {
|
||||
psum_t &grad_q, int xyz, const int8_t axis_sign=1) {
|
||||
prism_backward_translation(
|
||||
make_float3(cumul_mat[0 + xyz], cumul_mat[4 + xyz], cumul_mat[8 + xyz]),
|
||||
loc_grad, grad_q);
|
||||
loc_grad, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void x_prism_backward_translation(float *cumul_mat,
|
||||
float3 &loc_grad,
|
||||
psum_t &grad_q) {
|
||||
psum_t &grad_q,
|
||||
const int8_t axis_sign=1) {
|
||||
// get rotation vector:
|
||||
prism_backward_translation(
|
||||
make_float3(cumul_mat[0], cumul_mat[4], cumul_mat[8]), loc_grad, grad_q);
|
||||
make_float3(cumul_mat[0], cumul_mat[4], cumul_mat[8]), loc_grad, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void y_prism_backward_translation(float *cumul_mat,
|
||||
float3 &loc_grad,
|
||||
psum_t &grad_q) {
|
||||
psum_t &grad_q,
|
||||
const int8_t axis_sign=1) {
|
||||
// get rotation vector:
|
||||
prism_backward_translation(
|
||||
make_float3(cumul_mat[1], cumul_mat[5], cumul_mat[9]), loc_grad, grad_q);
|
||||
make_float3(cumul_mat[1], cumul_mat[5], cumul_mat[9]), loc_grad, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void z_prism_backward_translation(float *cumul_mat,
|
||||
float3 &loc_grad,
|
||||
psum_t &grad_q) {
|
||||
psum_t &grad_q,
|
||||
const int8_t axis_sign=1) {
|
||||
// get rotation vector:
|
||||
prism_backward_translation(
|
||||
make_float3(cumul_mat[2], cumul_mat[6], cumul_mat[10]), loc_grad, grad_q);
|
||||
make_float3(cumul_mat[2], cumul_mat[6], cumul_mat[10]), loc_grad, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
xyz_rot_backward_translation(float *cumul_mat, float *l_pos, float3 &loc_grad,
|
||||
float &grad_q, int xyz) {
|
||||
float &grad_q, int xyz, const int8_t axis_sign=1) {
|
||||
// get rotation vector:
|
||||
rot_backward_translation(
|
||||
make_float3(cumul_mat[0 + xyz], cumul_mat[4 + xyz], cumul_mat[8 + xyz]),
|
||||
&cumul_mat[0], &l_pos[0], loc_grad, grad_q);
|
||||
&cumul_mat[0], &l_pos[0], loc_grad, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
x_rot_backward_translation(float *cumul_mat, float *l_pos, float3 &loc_grad,
|
||||
psum_t &grad_q) {
|
||||
psum_t &grad_q, const int8_t axis_sign=1) {
|
||||
// get rotation vector:
|
||||
rot_backward_translation(
|
||||
make_float3(cumul_mat[0], cumul_mat[4], cumul_mat[8]), &cumul_mat[0],
|
||||
&l_pos[0], loc_grad, grad_q);
|
||||
&l_pos[0], loc_grad, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
y_rot_backward_translation(float *cumul_mat, float *l_pos, float3 &loc_grad,
|
||||
psum_t &grad_q) {
|
||||
psum_t &grad_q, const int8_t axis_sign=1) {
|
||||
// get rotation vector:
|
||||
rot_backward_translation(
|
||||
make_float3(cumul_mat[1], cumul_mat[5], cumul_mat[9]), &cumul_mat[0],
|
||||
&l_pos[0], loc_grad, grad_q);
|
||||
&l_pos[0], loc_grad, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
template <typename psum_t>
|
||||
__device__ __forceinline__ void
|
||||
z_rot_backward_translation(float *cumul_mat, float *l_pos, float3 &loc_grad,
|
||||
psum_t &grad_q) {
|
||||
psum_t &grad_q, const int8_t axis_sign=1) {
|
||||
// get rotation vector:
|
||||
rot_backward_translation(
|
||||
make_float3(cumul_mat[2], cumul_mat[6], cumul_mat[10]), &cumul_mat[0],
|
||||
&l_pos[0], loc_grad, grad_q);
|
||||
&l_pos[0], loc_grad, grad_q, axis_sign);
|
||||
}
|
||||
|
||||
// An optimized version of kin_fused_warp_kernel.
|
||||
@@ -604,7 +595,7 @@ kin_fused_warp_kernel2(scalar_t *link_pos, // batchSize xz store_n_links x M x M
|
||||
*(float4 *)&global_cumul_mat[batch * nlinks * 16 + col_idx * M] =
|
||||
*(float4 *)&cumul_mat[matAddrBase + col_idx * M];
|
||||
}
|
||||
for (int8_t l = 1; l < nlinks; l++) // TODO: add base link transform
|
||||
for (int8_t l = 1; l < nlinks; l++) //
|
||||
{
|
||||
|
||||
// get one row of fixedTransform
|
||||
@@ -616,21 +607,27 @@ kin_fused_warp_kernel2(scalar_t *link_pos, // batchSize xz store_n_links x M x M
|
||||
// check joint type and use one of the helper functions:
|
||||
float JM[M];
|
||||
int j_type = jointMapType[l];
|
||||
float angle = q[batch * njoints + jointMap[l]];
|
||||
if (j_type == FIXED) {
|
||||
|
||||
if (j_type == FIXED)
|
||||
{
|
||||
fixed_joint_fn(&fixedTransform[ftAddrStart + col_idx], &JM[0]);
|
||||
} else if (j_type <= Z_PRISM) {
|
||||
prism_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0], j_type);
|
||||
} else if (j_type == X_ROT) {
|
||||
xrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else if (j_type == Y_ROT) {
|
||||
yrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else if (j_type == Z_ROT) {
|
||||
zrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else {
|
||||
assert(jointMapType[l] > -2 &&
|
||||
jointMapType[l] < 6); // joint type not supported
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
float angle = q[batch * njoints + jointMap[l]];
|
||||
update_axis_direction(angle, j_type);
|
||||
if (j_type <= Z_PRISM) {
|
||||
prism_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0], j_type);
|
||||
} else if (j_type == X_ROT) {
|
||||
xrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else if (j_type == Y_ROT) {
|
||||
yrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else if (j_type == Z_ROT) {
|
||||
zrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else {
|
||||
assert(j_type >= FIXED && j_type <= Z_ROT);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll 4
|
||||
for (int i = 0; i < M; i++) {
|
||||
@@ -652,7 +649,8 @@ kin_fused_warp_kernel2(scalar_t *link_pos, // batchSize xz store_n_links x M x M
|
||||
int16_t read_cumul_idx = -1;
|
||||
int16_t spheres_perthread = (nspheres + 3) / 4;
|
||||
for (int16_t i = 0; i < spheres_perthread; i++) {
|
||||
const int16_t sph_idx = col_idx * spheres_perthread + i;
|
||||
//const int16_t sph_idx = col_idx * spheres_perthread + i;
|
||||
const int16_t sph_idx = col_idx + i * 4;
|
||||
// const int8_t sph_idx =
|
||||
// i * 4 + col_idx; // different order such that adjacent
|
||||
// spheres are in neighboring threads
|
||||
@@ -749,23 +747,30 @@ __global__ void kin_fused_backward_kernel3(
|
||||
int inAddrStart = matAddrBase + linkMap[l] * M * M;
|
||||
int outAddrStart = matAddrBase + l * M * M; // + (t % M) * M;
|
||||
|
||||
float angle = q[batch * njoints + jointMap[l]];
|
||||
int j_type = jointMapType[l];
|
||||
|
||||
|
||||
if (j_type == FIXED) {
|
||||
fixed_joint_fn(&fixedTransform[ftAddrStart + col_idx], &JM[0]);
|
||||
} else if (j_type == Z_ROT) {
|
||||
zrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else if (j_type <= Z_PRISM) {
|
||||
prism_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0], j_type);
|
||||
} else if (j_type == X_ROT) {
|
||||
xrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else if (j_type == Y_ROT) {
|
||||
yrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else {
|
||||
assert(jointMapType[l] > -2 &&
|
||||
jointMapType[l] < 6); // joint type not supported
|
||||
}
|
||||
}
|
||||
else {
|
||||
float angle = q[batch * njoints + jointMap[l]];
|
||||
|
||||
update_axis_direction(angle, j_type);
|
||||
|
||||
if (j_type <= Z_PRISM) {
|
||||
prism_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0], j_type);
|
||||
} else if (j_type == X_ROT) {
|
||||
xrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else if (j_type == Y_ROT) {
|
||||
yrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else if (j_type == Z_ROT) {
|
||||
zrot_fn(&fixedTransform[ftAddrStart], angle, col_idx, &JM[0]);
|
||||
} else {
|
||||
assert(j_type >= FIXED && j_type <= Z_ROT);
|
||||
}
|
||||
|
||||
}
|
||||
// fetch one row of cumul_mat, multiply with a column, which is in JM
|
||||
cumul_mat[outAddrStart + elem_idx] =
|
||||
dot(*(float4 *)&cumul_mat[inAddrStart + ((elem_idx / 4) * M)],
|
||||
@@ -789,7 +794,8 @@ __global__ void kin_fused_backward_kernel3(
|
||||
const int spheres_perthread = (nspheres + 15) / 16;
|
||||
|
||||
for (int i = 0; i < spheres_perthread; i++) {
|
||||
const int sph_idx = elem_idx * spheres_perthread + i;
|
||||
//const int sph_idx = elem_idx * spheres_perthread + i;
|
||||
const int sph_idx = elem_idx + i * 16;
|
||||
if (sph_idx >= nspheres) {
|
||||
break;
|
||||
}
|
||||
@@ -808,40 +814,49 @@ __global__ void kin_fused_backward_kernel3(
|
||||
|
||||
// read cumul idx:
|
||||
read_cumul_idx = linkSphereMap[sph_idx];
|
||||
|
||||
float spheres_mem[4];
|
||||
transform_sphere(&cumul_mat[matAddrBase + read_cumul_idx * 16],
|
||||
&robotSpheres[sphAddrs], &spheres_mem[0]);
|
||||
for (int j = read_cumul_idx; j > -1; j--) {
|
||||
// assuming this sphere only depends on links lower than this index
|
||||
// This could be relaxed by making read_cumul_idx = number of links.
|
||||
//const int16_t loop_max = read_cumul_idx;
|
||||
const int16_t loop_max = nlinks - 1;
|
||||
for (int j = loop_max; j > -1; j--) {
|
||||
if (linkChainMap[read_cumul_idx * nlinks + j] == 0.0) {
|
||||
continue;
|
||||
}
|
||||
int8_t axis_sign = 1;
|
||||
|
||||
int j_type = jointMapType[j];
|
||||
if(j_type != FIXED)
|
||||
{
|
||||
update_joint_type_direction(j_type, axis_sign);
|
||||
}
|
||||
if (j_type == Z_ROT) {
|
||||
float result = 0.0;
|
||||
z_rot_backward_translation(&cumul_mat[matAddrBase + j * 16],
|
||||
&spheres_mem[0], loc_grad_sphere, result);
|
||||
&spheres_mem[0], loc_grad_sphere, result, axis_sign);
|
||||
psum_grad[jointMap[j]] += (psum_t)result;
|
||||
} else if (j_type >= X_PRISM && j_type <= Z_PRISM) {
|
||||
float result = 0.0;
|
||||
xyz_prism_backward_translation(&cumul_mat[matAddrBase + j * 16],
|
||||
loc_grad_sphere, result, j_type);
|
||||
loc_grad_sphere, result, j_type, axis_sign);
|
||||
psum_grad[jointMap[j]] += (psum_t)result;
|
||||
} else if (j_type == X_ROT) {
|
||||
float result = 0.0;
|
||||
x_rot_backward_translation(&cumul_mat[matAddrBase + j * 16],
|
||||
&spheres_mem[0], loc_grad_sphere, result);
|
||||
&spheres_mem[0], loc_grad_sphere, result, axis_sign);
|
||||
psum_grad[jointMap[j]] += (psum_t)result;
|
||||
} else if (j_type == Y_ROT) {
|
||||
float result = 0.0;
|
||||
y_rot_backward_translation(&cumul_mat[matAddrBase + j * 16],
|
||||
&spheres_mem[0], loc_grad_sphere, result);
|
||||
&spheres_mem[0], loc_grad_sphere, result, axis_sign);
|
||||
psum_grad[jointMap[j]] += (psum_t)result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Instead of accumulating the sphere_grad and link_grad separately, we will
|
||||
// accumulate them together once below.
|
||||
//
|
||||
@@ -867,7 +882,7 @@ __global__ void kin_fused_backward_kernel3(
|
||||
float3 g_position = *(float3 *)&grad_nlinks_pos[batchAddrs * 3 + i * 3];
|
||||
float4 g_orientation_t =
|
||||
*(float4 *)&grad_nlinks_quat[batchAddrs * 4 + i * 4];
|
||||
// TODO: sparisty check here:
|
||||
// sparisty check here:
|
||||
if (enable_sparsity_opt) {
|
||||
if (g_position.x == 0 && g_position.y == 0 && g_position.z == 0 &&
|
||||
g_orientation_t.y == 0 && g_orientation_t.z == 0 &&
|
||||
@@ -879,6 +894,7 @@ __global__ void kin_fused_backward_kernel3(
|
||||
make_float3(g_orientation_t.y, g_orientation_t.z, g_orientation_t.w);
|
||||
|
||||
const int16_t l_map = storeLinkMap[i];
|
||||
|
||||
|
||||
float l_pos[3];
|
||||
const int outAddrStart = matAddrBase + l_map * M * M;
|
||||
@@ -886,34 +902,47 @@ __global__ void kin_fused_backward_kernel3(
|
||||
l_pos[1] = cumul_mat[outAddrStart + M + 3];
|
||||
l_pos[2] = cumul_mat[outAddrStart + M * 2 + 3];
|
||||
|
||||
int16_t joints_per_thread = (l_map + 15) / 16;
|
||||
for (int16_t k = joints_per_thread; k >= 0; k--) {
|
||||
int16_t j = k * M + elem_idx;
|
||||
if (j > l_map || j < 0)
|
||||
const int16_t max_lmap = nlinks - 1;
|
||||
|
||||
const int16_t joints_per_thread = (max_lmap + 15) / 16;
|
||||
//for (int16_t k = joints_per_thread; k >= 0; k--)
|
||||
for (int16_t k=0; k < joints_per_thread; k++)
|
||||
{
|
||||
|
||||
//int16_t j = elem_idx * joints_per_thread + k;
|
||||
int16_t j = elem_idx + k * 16;
|
||||
//int16_t j = k * M + elem_idx;
|
||||
if (j > max_lmap || j < 0)
|
||||
continue;
|
||||
// This can be spread across threads as they are not sequential?
|
||||
if (linkChainMap[l_map * nlinks + j] == 0.0) {
|
||||
continue;
|
||||
}
|
||||
int16_t j_idx = jointMap[j];
|
||||
int8_t j_type = jointMapType[j];
|
||||
int j_type = jointMapType[j];
|
||||
|
||||
int8_t axis_sign = 1;
|
||||
if (j_type != FIXED)
|
||||
{
|
||||
update_joint_type_direction(j_type, axis_sign);
|
||||
}
|
||||
|
||||
// get rotation vector:
|
||||
if (j_type == Z_ROT) {
|
||||
z_rot_backward(&cumul_mat[matAddrBase + (j)*M * M], &l_pos[0],
|
||||
g_position, g_orientation, psum_grad[j_idx]);
|
||||
g_position, g_orientation, psum_grad[j_idx], axis_sign);
|
||||
} else if (j_type >= X_PRISM & j_type <= Z_PRISM) {
|
||||
xyz_prism_backward_translation(&cumul_mat[matAddrBase + j * 16],
|
||||
g_position, psum_grad[j_idx], j_type);
|
||||
g_position, psum_grad[j_idx], j_type, axis_sign);
|
||||
} else if (j_type == X_ROT) {
|
||||
x_rot_backward(&cumul_mat[matAddrBase + (j)*M * M], &l_pos[0],
|
||||
g_position, g_orientation, psum_grad[j_idx]);
|
||||
g_position, g_orientation, psum_grad[j_idx], axis_sign);
|
||||
} else if (j_type == Y_ROT) {
|
||||
y_rot_backward(&cumul_mat[matAddrBase + (j)*M * M], &l_pos[0],
|
||||
g_position, g_orientation, psum_grad[j_idx]);
|
||||
g_position, g_orientation, psum_grad[j_idx], axis_sign);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (PARALLEL_WRITE) {
|
||||
|
||||
// accumulate the partial sums across the 16 threads
|
||||
@@ -931,7 +960,8 @@ __global__ void kin_fused_backward_kernel3(
|
||||
|
||||
#pragma unroll
|
||||
for (int16_t j = 0; j < joints_per_thread; j++) {
|
||||
const int16_t j_idx = elem_idx * joints_per_thread + j;
|
||||
//const int16_t j_idx = elem_idx * joints_per_thread + j;
|
||||
const int16_t j_idx = elem_idx + j * 16;
|
||||
if (j_idx >= njoints) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -441,7 +441,7 @@ __global__ void reduce_kernel(
|
||||
rho_buffer[threadIdx.x * batchsize + batch] = rho;
|
||||
}
|
||||
}
|
||||
template <typename scalar_t, typename psum_t>
|
||||
template <typename scalar_t, typename psum_t, bool rolled_ys>
|
||||
__global__ void lbfgs_update_buffer_and_step_v1(
|
||||
scalar_t *step_vec, // b x 175
|
||||
scalar_t *rho_buffer, // m x b x 1
|
||||
@@ -452,7 +452,6 @@ __global__ void lbfgs_update_buffer_and_step_v1(
|
||||
scalar_t *grad_0, // b x 175
|
||||
const scalar_t *grad_q, // b x 175
|
||||
const float epsilon, const int batchsize, const int m, const int v_dim,
|
||||
const bool rolled_ys = false,
|
||||
const bool stable_mode =
|
||||
false) // s_buffer and y_buffer are not rolled by default
|
||||
{
|
||||
@@ -485,6 +484,7 @@ __global__ void lbfgs_update_buffer_and_step_v1(
|
||||
scalar_t s =
|
||||
q[batch * v_dim + threadIdx.x] - x_0[batch * v_dim + threadIdx.x];
|
||||
reduce_v1(y * s, v_dim, &data[0], &result);
|
||||
//reduce(y * s, v_dim, &data[0], &result);
|
||||
scalar_t numerator = result;
|
||||
// scalar_t rho = 1.0/numerator;
|
||||
|
||||
@@ -824,14 +824,14 @@ lbfgs_cuda_fuse(torch::Tensor step_vec, torch::Tensor rho_buffer,
|
||||
y_buffer.scalar_type(), "lbfgs_cuda_fuse_kernel_v1", [&] {
|
||||
smemsize = 3 * m * threadsPerBlock * sizeof(scalar_t) +
|
||||
m * batch_size * sizeof(scalar_t);
|
||||
lbfgs_update_buffer_and_step_v1<scalar_t, scalar_t>
|
||||
lbfgs_update_buffer_and_step_v1<scalar_t, scalar_t, false>
|
||||
<<<blocksPerGrid, threadsPerBlock, smemsize, stream>>>(
|
||||
step_vec.data_ptr<scalar_t>(),
|
||||
rho_buffer.data_ptr<scalar_t>(),
|
||||
y_buffer.data_ptr<scalar_t>(), s_buffer.data_ptr<scalar_t>(),
|
||||
q.data_ptr<scalar_t>(), x_0.data_ptr<scalar_t>(),
|
||||
grad_0.data_ptr<scalar_t>(), grad_q.data_ptr<scalar_t>(),
|
||||
epsilon, batch_size, m, v_dim, false, stable_mode);
|
||||
epsilon, batch_size, m, v_dim, stable_mode);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -164,10 +164,14 @@ compute_distance(float *distance_vec, float &distance, float &position_distance,
|
||||
distance = 0;
|
||||
if (rotation_distance > vec_convergence[0] * vec_convergence[0]) {
|
||||
rotation_distance = sqrtf(rotation_distance);
|
||||
//rotation_distance -= vec_convergence[0];
|
||||
|
||||
distance += rotation_weight * rotation_distance;
|
||||
}
|
||||
if (position_distance > vec_convergence[1] * vec_convergence[1]) {
|
||||
position_distance = sqrtf(position_distance);
|
||||
//position_distance -= vec_convergence[1];
|
||||
|
||||
distance += position_weight * position_distance;
|
||||
}
|
||||
}
|
||||
@@ -202,9 +206,12 @@ __device__ __forceinline__ void compute_metric_distance(
|
||||
distance = 0;
|
||||
if (rotation_distance > vec_convergence[0] * vec_convergence[0]) {
|
||||
rotation_distance = sqrtf(rotation_distance);
|
||||
//rotation_distance -= vec_convergence[0];
|
||||
|
||||
distance += rotation_weight * log2f(coshf(r_alpha * rotation_distance));
|
||||
}
|
||||
if (position_distance > vec_convergence[1] * vec_convergence[1]) {
|
||||
//position_distance -= vec_convergence[1];
|
||||
position_distance = sqrtf(position_distance);
|
||||
distance += position_weight * log2f(coshf(p_alpha * position_distance));
|
||||
}
|
||||
@@ -372,6 +379,14 @@ __global__ void goalset_pose_distance_kernel(
|
||||
// write out pose distance:
|
||||
out_distance[batch_idx * horizon + h_idx] = best_distance;
|
||||
if (write_distance) {
|
||||
if(position_weight == 0.0)
|
||||
{
|
||||
best_position_distance = 0.0;
|
||||
}
|
||||
if (rotation_weight == 0.0)
|
||||
{
|
||||
best_rotation_distance = 0.0;
|
||||
}
|
||||
out_position_distance[batch_idx * horizon + h_idx] = best_position_distance;
|
||||
out_rotation_distance[batch_idx * horizon + h_idx] = best_rotation_distance;
|
||||
}
|
||||
@@ -522,6 +537,14 @@ __global__ void goalset_pose_distance_metric_kernel(
|
||||
// write out pose distance:
|
||||
out_distance[batch_idx * horizon + h_idx] = best_distance;
|
||||
if (write_distance) {
|
||||
if(position_weight == 0.0)
|
||||
{
|
||||
best_position_distance = 0.0;
|
||||
}
|
||||
if (rotation_weight == 0.0)
|
||||
{
|
||||
best_rotation_distance = 0.0;
|
||||
}
|
||||
out_position_distance[batch_idx * horizon + h_idx] = best_position_distance;
|
||||
out_rotation_distance[batch_idx * horizon + h_idx] = best_rotation_distance;
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ std::vector<torch::Tensor> step_position_clique_wrapper(
|
||||
const torch::Tensor traj_dt, const int batch_size, const int horizon,
|
||||
const int dof) {
|
||||
const at::cuda::OptionalCUDAGuard guard(u_position.device());
|
||||
|
||||
assert(false); // not supported
|
||||
CHECK_INPUT(u_position);
|
||||
CHECK_INPUT(out_position);
|
||||
CHECK_INPUT(out_velocity);
|
||||
@@ -155,7 +155,7 @@ std::vector<torch::Tensor> backward_step_position_clique_wrapper(
|
||||
const torch::Tensor grad_jerk, const torch::Tensor traj_dt,
|
||||
const int batch_size, const int horizon, const int dof) {
|
||||
const at::cuda::OptionalCUDAGuard guard(grad_position.device());
|
||||
|
||||
assert(false); // not supported
|
||||
CHECK_INPUT(out_grad_position);
|
||||
CHECK_INPUT(grad_position);
|
||||
CHECK_INPUT(grad_velocity);
|
||||
|
||||
@@ -319,14 +319,16 @@ __device__ __forceinline__ void compute_central_difference(scalar_t *out_positio
|
||||
|
||||
const float dt = traj_dt[0]; // assume same dt across traj TODO: Implement variable dt
|
||||
// dt here is actually 1/dt;
|
||||
|
||||
const float dt_inv = 1.0 / dt;
|
||||
const float st_jerk = 0.0; // Note: start jerk can also be passed from global memory
|
||||
// read start state:
|
||||
float out_pos=0.0, out_vel=0.0, out_acc=0.0, out_jerk=0.0;
|
||||
float st_pos=0.0, st_vel=0.0, st_acc = 0.0;
|
||||
|
||||
const int b_addrs = b_idx * horizon * dof;
|
||||
const int b_addrs_action = b_idx * (horizon-4) * dof;
|
||||
float in_pos[5]; // create a 5 value scalar
|
||||
|
||||
const float acc_scale = 1.0;
|
||||
#pragma unroll 5
|
||||
for (int i=0; i< 5; i ++){
|
||||
in_pos[i] = 0.0;
|
||||
@@ -337,92 +339,108 @@ __device__ __forceinline__ void compute_central_difference(scalar_t *out_positio
|
||||
st_acc = start_acceleration[b_offset * dof + d_idx];
|
||||
}
|
||||
|
||||
if (h_idx > 3 && h_idx < horizon - 4)
|
||||
{
|
||||
in_pos[0] = u_position[b_addrs + (h_idx - 4) * dof + d_idx];
|
||||
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[2] = u_position[b_addrs + (h_idx - 2 ) * dof + d_idx];
|
||||
in_pos[3] = u_position[b_addrs + (h_idx - 1) * dof + d_idx];
|
||||
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
|
||||
if (h_idx > 3 && h_idx < horizon - 4){
|
||||
in_pos[0] = u_position[b_addrs_action + (h_idx - 4) * dof + d_idx];
|
||||
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[2] = u_position[b_addrs_action + (h_idx - 2 ) * dof + d_idx];
|
||||
in_pos[3] = u_position[b_addrs_action + (h_idx - 1) * dof + d_idx];
|
||||
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
|
||||
|
||||
}
|
||||
|
||||
|
||||
else if (h_idx == 0)
|
||||
{
|
||||
in_pos[0] = st_pos - 3 * dt * ( st_vel + (0.5 * st_acc * dt)); // start -1, start, u0, u1
|
||||
in_pos[1] = st_pos - 2 * dt * ( st_vel + (0.5 * st_acc * dt));
|
||||
in_pos[2] = st_pos - dt * ( st_vel + (0.5 * st_acc * dt));
|
||||
|
||||
|
||||
|
||||
|
||||
in_pos[0] = (3.0f/2) * ( - 1 * st_acc * (dt_inv * dt_inv) - (dt_inv * dt_inv * dt_inv) * st_jerk ) - 3.0f * dt_inv * st_vel + st_pos;
|
||||
in_pos[1] = -2.0f * st_acc * dt_inv * dt_inv - (4.0f/3) * dt_inv * dt_inv * dt_inv * st_jerk - 2.0 * dt_inv * st_vel + st_pos;
|
||||
in_pos[2] = -(3.0f/2) * st_acc * dt_inv * dt_inv - (7.0f/6) * dt_inv * dt_inv * dt_inv * st_jerk - dt_inv * st_vel + st_pos;
|
||||
in_pos[3] = st_pos;
|
||||
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
|
||||
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
else if (h_idx == 1)
|
||||
{
|
||||
in_pos[0] = st_pos - 2 * dt * ( st_vel + (0.5 * st_acc * dt)); // start -1, start, u0, u1
|
||||
in_pos[1] = st_pos - dt * ( st_vel + (0.5 * st_acc * dt));
|
||||
|
||||
in_pos[0] = -2.0f * st_acc * dt_inv * dt_inv - (4.0f/3) * dt_inv * dt_inv * dt_inv * st_jerk - 2.0 * dt_inv * st_vel + st_pos;
|
||||
in_pos[1] = -(3.0f/2) * st_acc * dt_inv * dt_inv - (7.0f/6) * dt_inv * dt_inv * dt_inv * st_jerk - dt_inv * st_vel + st_pos;
|
||||
|
||||
|
||||
in_pos[2] = st_pos;
|
||||
in_pos[3] = u_position[b_addrs + (h_idx - 1) * dof + d_idx];
|
||||
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
|
||||
in_pos[3] = u_position[b_addrs_action + (h_idx - 1) * dof + d_idx];
|
||||
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
|
||||
|
||||
|
||||
}
|
||||
|
||||
else if (h_idx == 2)
|
||||
{
|
||||
in_pos[0] = st_pos - dt * ( st_vel + (0.5 * st_acc * dt)); // start -1, start, u0, u1
|
||||
in_pos[0] = -(3.0f/2) * st_acc * dt_inv * dt_inv - (7.0f/6) * dt_inv * dt_inv * dt_inv * st_jerk - dt_inv * st_vel + st_pos;
|
||||
in_pos[1] = st_pos;
|
||||
in_pos[2] = u_position[b_addrs + (h_idx - 2) * dof + d_idx];
|
||||
in_pos[3] = u_position[b_addrs + (h_idx - 1) * dof + d_idx];
|
||||
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
|
||||
in_pos[2] = u_position[b_addrs_action + (h_idx - 2) * dof + d_idx];
|
||||
in_pos[3] = u_position[b_addrs_action + (h_idx - 1) * dof + d_idx];
|
||||
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
|
||||
|
||||
|
||||
}
|
||||
else if (h_idx == 3)
|
||||
{
|
||||
in_pos[0] = st_pos;
|
||||
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[2] = u_position[b_addrs + (h_idx - 2 ) * dof + d_idx];
|
||||
in_pos[3] = u_position[b_addrs + (h_idx - 1 ) * dof + d_idx];
|
||||
in_pos[4] = u_position[b_addrs + (h_idx) * dof + d_idx];
|
||||
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[2] = u_position[b_addrs_action + (h_idx - 2 ) * dof + d_idx];
|
||||
in_pos[3] = u_position[b_addrs_action + (h_idx - 1 ) * dof + d_idx];
|
||||
in_pos[4] = u_position[b_addrs_action + (h_idx) * dof + d_idx];
|
||||
|
||||
}
|
||||
|
||||
else if (h_idx == horizon - 4)
|
||||
{
|
||||
in_pos[0] = u_position[b_addrs + (h_idx -4) * dof + d_idx];
|
||||
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[2] = u_position[b_addrs + (h_idx - 2 ) * dof + d_idx];
|
||||
in_pos[3] = u_position[b_addrs + (h_idx - 1) * dof + d_idx];
|
||||
in_pos[4] = in_pos[3];//in_pos[3]; //u_position[b_addrs + (h_idx - 1 + 2) * dof + d_idx];
|
||||
in_pos[0] = u_position[b_addrs_action + (h_idx -4) * dof + d_idx];
|
||||
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[2] = u_position[b_addrs_action + (h_idx - 2 ) * dof + d_idx];
|
||||
in_pos[3] = u_position[b_addrs_action + (h_idx - 1) * dof + d_idx];
|
||||
in_pos[4] = in_pos[3];//in_pos[3]; //u_position[b_addrs_action + (h_idx - 1 + 2) * dof + d_idx];
|
||||
|
||||
}
|
||||
|
||||
else if (h_idx == horizon - 3)
|
||||
{
|
||||
in_pos[0] = u_position[b_addrs + (h_idx - 4) * dof + d_idx];
|
||||
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[2] = u_position[b_addrs + (h_idx - 2) * dof + d_idx];
|
||||
in_pos[3] = in_pos[2];//u_position[b_addrs + (h_idx - 1 + 1) * dof + d_idx];
|
||||
in_pos[4] = in_pos[2];//in_pos[3]; //u_position[b_addrs + (h_idx - 1 + 2) * dof + d_idx];
|
||||
in_pos[0] = u_position[b_addrs_action + (h_idx - 4) * dof + d_idx];
|
||||
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[2] = u_position[b_addrs_action + (h_idx - 2) * dof + d_idx];
|
||||
in_pos[3] = in_pos[2];//u_position[b_addrs_action + (h_idx - 1 + 1) * dof + d_idx];
|
||||
in_pos[4] = in_pos[2];//in_pos[3]; //u_position[b_addrs_action + (h_idx - 1 + 2) * dof + d_idx];
|
||||
|
||||
}
|
||||
else if (h_idx == horizon - 2)
|
||||
{
|
||||
in_pos[0] = u_position[b_addrs + (h_idx - 4) * dof + d_idx];
|
||||
in_pos[1] = u_position[b_addrs + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[0] = u_position[b_addrs_action + (h_idx - 4) * dof + d_idx];
|
||||
in_pos[1] = u_position[b_addrs_action + (h_idx - 3) * dof + d_idx];
|
||||
in_pos[2] = in_pos[1];
|
||||
in_pos[3] = in_pos[1];//u_position[b_addrs + (h_idx - 1 + 1) * dof + d_idx];
|
||||
in_pos[4] = in_pos[1];//in_pos[3]; //u_position[b_addrs + (h_idx - 1 + 2) * dof + d_idx];
|
||||
in_pos[3] = in_pos[1];//u_position[b_addrs_action + (h_idx - 1 + 1) * dof + d_idx];
|
||||
in_pos[4] = in_pos[1];//in_pos[3]; //u_position[b_addrs_action + (h_idx - 1 + 2) * dof + d_idx];
|
||||
|
||||
}
|
||||
|
||||
else if (h_idx == horizon -1)
|
||||
else if (h_idx == horizon - 1)
|
||||
{
|
||||
in_pos[0] = u_position[b_addrs + (h_idx - 4) * dof + d_idx];
|
||||
in_pos[0] = u_position[b_addrs_action + (h_idx - 4) * dof + d_idx];
|
||||
in_pos[1] = in_pos[0];
|
||||
in_pos[2] = in_pos[0];//u_position[b_addrs + (h_idx - 1 ) * dof + d_idx];
|
||||
in_pos[3] = in_pos[0];//u_position[b_addrs + (h_idx - 1 + 1) * dof + d_idx];
|
||||
in_pos[4] = in_pos[0];//in_pos[3]; //u_position[b_addrs + (h_idx - 1 + 2) * dof + d_idx];
|
||||
in_pos[2] = in_pos[0];//u_position[b_addrs_action + (h_idx - 1 ) * dof + d_idx];
|
||||
in_pos[3] = in_pos[0];//u_position[b_addrs_action + (h_idx - 1 + 1) * dof + d_idx];
|
||||
in_pos[4] = in_pos[0];//in_pos[3]; //u_position[b_addrs_action + (h_idx - 1 + 2) * dof + d_idx];
|
||||
}
|
||||
out_pos = in_pos[2];
|
||||
|
||||
// out_vel = (0.5 * in_pos[3] - 0.5 * in_pos[1]) * dt;
|
||||
out_vel = ((0.083333333f) * in_pos[0] - (0.666666667f) * in_pos[1] + (0.666666667f) * in_pos[3] + (-0.083333333f) * in_pos[4]) * dt;
|
||||
|
||||
@@ -693,8 +711,9 @@ __global__ void backward_position_clique_loop_central_difference_kernel2(
|
||||
return;
|
||||
}
|
||||
const int b_addrs = b_idx * horizon * dof;
|
||||
const int b_addrs_action = b_idx * (horizon-4) * dof;
|
||||
|
||||
if (h_idx < 2 || h_idx > horizon - 2)
|
||||
if (h_idx < 2 || h_idx >= horizon - 2)
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -717,7 +736,7 @@ __global__ void backward_position_clique_loop_central_difference_kernel2(
|
||||
g_jerk[i] = 0.0;
|
||||
}
|
||||
|
||||
int hid = h_idx;
|
||||
const int hid = h_idx;
|
||||
|
||||
g_pos[0] = grad_position[b_addrs + (hid)*dof + d_idx];
|
||||
g_pos[1] = 0.0;
|
||||
@@ -745,30 +764,35 @@ __global__ void backward_position_clique_loop_central_difference_kernel2(
|
||||
(0.5f * g_jerk[0] - g_jerk[1] + g_jerk[3] - 0.5f * g_jerk[4]) * dt * dt * dt);
|
||||
|
||||
}
|
||||
else if (hid == horizon -3)
|
||||
else if (hid == horizon - 3)
|
||||
{
|
||||
//The below can cause oscilatory gradient steps.
|
||||
|
||||
/*
|
||||
#pragma unroll
|
||||
for (int i=0; i< 4; i++)
|
||||
for (int i=0; i< 5; i++)
|
||||
{
|
||||
g_vel[i] = grad_velocity[b_addrs + ((hid - 2) + i)*dof + d_idx];
|
||||
g_acc[i] = grad_acceleration[b_addrs + ((hid -2) + i)*dof + d_idx];
|
||||
g_jerk[i] = grad_jerk[b_addrs + ((hid -2) + i)*dof + d_idx];
|
||||
}
|
||||
*/
|
||||
g_pos[1] = grad_position[b_addrs + (hid + 1)*dof + d_idx];
|
||||
g_pos[2] = grad_position[b_addrs + (hid + 2)*dof + d_idx];
|
||||
|
||||
out_grad = (g_pos[0] + g_pos[1] + g_pos[2] +
|
||||
out_grad = (g_pos[0] + g_pos[1] + g_pos[2]);
|
||||
/* +
|
||||
//((0.5) * g_vel[1] + (0.5) * g_vel[2]) * dt +
|
||||
((-0.083333333f) * g_vel[0] + (0.583333333f) * g_vel[1] + (0.583333333f) * g_vel[2] + (-0.083333333f) * g_vel[3]) * dt +
|
||||
((-0.083333333f) * g_acc[0] + (1.25f) * g_acc[1] + (-1.25f) * g_acc[2] + (0.083333333f) * g_acc[3]) * dt * dt +
|
||||
//( g_acc[1] - g_acc[2]) * dt * dt +
|
||||
(0.5f * g_jerk[0] - 0.5f * g_jerk[1] -0.5f * g_jerk[2] + 0.5f * g_jerk[3]) * dt * dt * dt);
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// write out:
|
||||
out_grad_position[b_addrs + (h_idx-2)*dof + d_idx] = out_grad;
|
||||
out_grad_position[b_addrs_action + (h_idx-2)*dof + d_idx] = out_grad;
|
||||
}
|
||||
|
||||
// for MPPI:
|
||||
@@ -1389,6 +1413,8 @@ std::vector<torch::Tensor> step_position_clique2(
|
||||
}
|
||||
else if (mode == CENTRAL_DIFF)
|
||||
{
|
||||
assert(u_position.sizes()[1] == horizon - 4);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
out_position.scalar_type(), "step_position_clique", ([&] {
|
||||
position_clique_loop_kernel2<scalar_t, CENTRAL_DIFF>
|
||||
@@ -1436,7 +1462,7 @@ std::vector<torch::Tensor> step_position_clique2_idx(
|
||||
if (mode == BWD_DIFF)
|
||||
{
|
||||
|
||||
|
||||
assert(false);
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
out_position.scalar_type(), "step_position_clique", ([&] {
|
||||
position_clique_loop_idx_kernel2<scalar_t, BWD_DIFF>
|
||||
@@ -1455,6 +1481,8 @@ std::vector<torch::Tensor> step_position_clique2_idx(
|
||||
|
||||
else if (mode == CENTRAL_DIFF)
|
||||
{
|
||||
assert(u_position.sizes()[1] == horizon - 4);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
out_position.scalar_type(), "step_position_clique", ([&] {
|
||||
position_clique_loop_idx_kernel2<scalar_t, CENTRAL_DIFF>
|
||||
@@ -1538,7 +1566,7 @@ std::vector<torch::Tensor> backward_step_position_clique2(
|
||||
if (mode == BWD_DIFF)
|
||||
{
|
||||
|
||||
|
||||
assert(false); // not supported anymore
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
out_grad_position.scalar_type(), "backward_step_position_clique", ([&] {
|
||||
backward_position_clique_loop_backward_difference_kernel2<scalar_t>
|
||||
@@ -1554,7 +1582,7 @@ std::vector<torch::Tensor> backward_step_position_clique2(
|
||||
}
|
||||
else if (mode == CENTRAL_DIFF)
|
||||
{
|
||||
|
||||
assert(out_grad_position.sizes()[1] == horizon - 4);
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
out_grad_position.scalar_type(), "backward_step_position_clique", ([&] {
|
||||
backward_position_clique_loop_central_difference_kernel2<scalar_t>
|
||||
|
||||
Reference in New Issue
Block a user