Improved precision, quality and js planner.
This commit is contained in:
@@ -481,7 +481,7 @@ namespace Curobo
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
delta = make_float3(pt.x - sphere.x, pt.y - sphere.y, pt.z - sphere.z);
|
||||
@@ -2746,13 +2746,8 @@ sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 3
|
||||
else
|
||||
{
|
||||
|
||||
#if CHECK_FP8
|
||||
const auto fp8_type = torch::kFloat8_e4m3fn;
|
||||
#else
|
||||
const auto fp8_type = torch::kHalf;
|
||||
#endif
|
||||
// typename scalar_t, typename dist_scalar_t=float, bool BATCH_ENV_T=true, bool SCALE_METRIC=true, bool SUM_COLLISIONS=true, bool COMPUTE_ESDF=false
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, fp8_type,
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, FP8_TYPE_MACRO,
|
||||
distance.scalar_type(), "SphereObb_clpt", ([&]{
|
||||
auto distance_kernel = sphere_obb_distance_kernel<scalar_t, scalar_t, false, scale_metric, sum_collisions_,false>;
|
||||
if (use_batch_env)
|
||||
@@ -3211,13 +3206,8 @@ sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
|
||||
|
||||
int blocksPerGrid = (bnh_spheres + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
#if CHECK_FP8
|
||||
const auto fp8_type = torch::kFloat8_e4m3fn;
|
||||
#else
|
||||
const auto fp8_type = torch::kHalf;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, fp8_type,
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, FP8_TYPE_MACRO,
|
||||
grid_features.scalar_type(), "SphereVoxel_clpt", ([&]
|
||||
{
|
||||
|
||||
@@ -3255,7 +3245,6 @@ sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
selected_kernel
|
||||
<< < blocksPerGrid, threadsPerBlock, 0, stream >> > (
|
||||
sphere_position.data_ptr<float>(),
|
||||
@@ -3275,7 +3264,6 @@ sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
|
||||
num_voxels,
|
||||
batch_size,
|
||||
horizon, n_spheres, transform_back);
|
||||
|
||||
}));
|
||||
|
||||
|
||||
@@ -3346,13 +3334,7 @@ swept_sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
|
||||
|
||||
int blocksPerGrid = (bnh_spheres + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
#if CHECK_FP8
|
||||
const auto fp8_type = torch::kFloat8_e4m3fn;
|
||||
#else
|
||||
const auto fp8_type = torch::kHalf;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, fp8_type,
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(torch::kBFloat16, FP8_TYPE_MACRO,
|
||||
grid_features.scalar_type(), "SphereVoxel_clpt", ([&] {
|
||||
|
||||
auto collision_kernel_n = swept_sphere_voxel_distance_jump_kernel<scalar_t, float, float, false, scale_metric, true, false, 100>;
|
||||
|
||||
Reference in New Issue
Block a user