Improved precision, quality and js planner.

This commit is contained in:
Balakumar Sundaralingam
2024-04-11 13:19:01 -07:00
parent 774dcfd609
commit d6e600c88c
51 changed files with 2128 additions and 1054 deletions

View File

@@ -481,7 +481,7 @@ namespace Curobo
}
delta = make_float3(pt.x - sphere.x, pt.y - sphere.y, pt.z - sphere.z);
@@ -2746,13 +2746,8 @@ sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 3
else
{
#if CHECK_FP8
const auto fp8_type = torch::kFloat8_e4m3fn;
#else
const auto fp8_type = torch::kHalf;
#endif
// typename scalar_t, typename dist_scalar_t=float, bool BATCH_ENV_T=true, bool SCALE_METRIC=true, bool SUM_COLLISIONS=true, bool COMPUTE_ESDF=false
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, fp8_type,
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, FP8_TYPE_MACRO,
distance.scalar_type(), "SphereObb_clpt", ([&]{
auto distance_kernel = sphere_obb_distance_kernel<scalar_t, scalar_t, false, scale_metric, sum_collisions_,false>;
if (use_batch_env)
@@ -3211,13 +3206,8 @@ sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
int blocksPerGrid = (bnh_spheres + threadsPerBlock - 1) / threadsPerBlock;
#if CHECK_FP8
const auto fp8_type = torch::kFloat8_e4m3fn;
#else
const auto fp8_type = torch::kHalf;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, fp8_type,
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, FP8_TYPE_MACRO,
grid_features.scalar_type(), "SphereVoxel_clpt", ([&]
{
@@ -3255,7 +3245,6 @@ sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
}
}
selected_kernel
<< < blocksPerGrid, threadsPerBlock, 0, stream >> > (
sphere_position.data_ptr<float>(),
@@ -3275,7 +3264,6 @@ sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
num_voxels,
batch_size,
horizon, n_spheres, transform_back);
}));
@@ -3346,13 +3334,7 @@ swept_sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
int blocksPerGrid = (bnh_spheres + threadsPerBlock - 1) / threadsPerBlock;
#if CHECK_FP8
const auto fp8_type = torch::kFloat8_e4m3fn;
#else
const auto fp8_type = torch::kHalf;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, fp8_type,
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, FP8_TYPE_MACRO,
grid_features.scalar_type(), "SphereVoxel_clpt", ([&] {
auto collision_kernel_n = swept_sphere_voxel_distance_jump_kernel<scalar_t, float, float, false, scale_metric, true, false, 100>;