check start state validity and minor fixes

This commit is contained in:
Balakumar Sundaralingam
2024-05-14 22:16:12 -07:00
parent 7196be75f5
commit 911da8cb24
8 changed files with 175 additions and 44 deletions

View File

@@ -82,7 +82,7 @@ namespace Curobo
if (write_grad)
{
dist_vec = (sph1 - sph2) / distance;
dist_vec = normalize(sph1 - sph2);// / distance;
}
}
}
@@ -235,7 +235,7 @@ namespace Curobo
uint64_t nd = __shfl_down_sync(mask, *(uint64_t *)&max_d, offset);
dist_t d_temp = *(dist_t *)&nd;
if (d_temp.d > max_d.d)
if (((threadIdx.x + offset) < blockDim.x) && d_temp.d > max_d.d)
{
max_d = d_temp;
}
@@ -285,11 +285,12 @@ namespace Curobo
if (write_grad)
{
// NOTE: spheres can be read from rs_shared
float3 sph1 =
*(float3 *)&robot_spheres[4 * (batch_idx * nspheres + max_d.i)];
float3 sph2 =
*(float3 *)&robot_spheres[4 * (batch_idx * nspheres + max_d.j)];
float3 dist_vec = (sph1 - sph2) / max_d.d;
float3 dist_vec = normalize(sph1 - sph2);
*(float3 *)&out_vec[batch_idx * nspheres * 4 + max_d.i * 4] =
weight[0] * -1 * dist_vec;
*(float3 *)&out_vec[batch_idx * nspheres * 4 + max_d.j * 4] =
@@ -434,7 +435,7 @@ namespace Curobo
uint64_t nd = __shfl_down_sync(mask, *(uint64_t *)&max_d[l], offset);
dist_t d_temp = *(dist_t *)&nd;
if (d_temp.d > max_d[l].d)
if (((threadIdx.x + offset) < blockDim.x) && d_temp.d > max_d[l].d)
{
max_d[l] = d_temp;
}
@@ -493,13 +494,15 @@ namespace Curobo
if (write_grad)
{
// NOTE: spheres can also be read from rs_shared
float3 sph1 =
*(float3 *)&robot_spheres[4 *
((batch_idx + l) * nspheres + max_d.i)];
float3 sph2 =
*(float3 *)&robot_spheres[4 *
((batch_idx + l) * nspheres + max_d.j)];
float3 dist_vec = (sph1 - sph2) / max_d.d;
float3 dist_vec = normalize(sph1 - sph2);// / max_d.d;
*(float3 *)&out_vec[(batch_idx + l) * nspheres * 4 + max_d.i * 4] =
weight[0] * -1 * dist_vec;
*(float3 *)&out_vec[(batch_idx + l) * nspheres * 4 + max_d.j * 4] =