Add planning to grasp API

Balakumar Sundaralingam
2024-11-22 14:15:18 -08:00
parent 18e9ebd35f
commit 36ea382dab
38 changed files with 939 additions and 535 deletions

View File

@@ -852,7 +852,11 @@ class CudaRobotGenerator(CudaRobotGeneratorConfig):
if not valid_data:
use_experimental_kernel = False
log_warn("Self Collision checks are greater than 32 * 512, using slower kernel")
log_warn(
"Self Collision checks are greater than 32 * 512, using slower kernel."
+ " Number of spheres: "
+ str(self_collision_distance.shape[0])
)
if use_experimental_kernel:
self_coll_matrix = torch.zeros((2), device=self.tensor_args.device, dtype=torch.uint8)
else:

View File

@@ -108,21 +108,21 @@ swept_sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3
torch::Tensor sparsity_idx, const torch::Tensor weight,
const torch::Tensor activation_distance,
const torch::Tensor max_distance,
const torch::Tensor speed_dt,
const torch::Tensor grid_features, // n_boxes, 4, 4
const torch::Tensor grid_params, // n_boxes, 3
const torch::Tensor grid_pose, // n_boxes, 4, 4
const torch::Tensor grid_enable, // n_boxes, 4, 4
const torch::Tensor n_env_grid,
const torch::Tensor env_query_idx, // n_boxes, 4, 4
const int max_nobs,
const int batch_size,
const int horizon,
const int n_spheres,
const int sweep_steps,
const bool enable_speed_metric,
const bool transform_back,
const bool compute_distance,
const bool use_batch_env,
const bool sum_collisions);
@@ -145,14 +145,15 @@ std::vector<torch::Tensor>pose_distance(
const torch::Tensor offset_waypoint,
const torch::Tensor offset_tstep_fraction,
const torch::Tensor batch_pose_idx,
const torch::Tensor project_distance,
const int batch_size,
const int horizon,
const int mode,
const int num_goals = 1,
const bool compute_grad = false,
const bool write_distance = true,
const bool use_metric = false,
const bool project_distance = true);
const bool use_metric = false
);
std::vector<torch::Tensor>
backward_pose_distance(torch::Tensor out_grad_p,
@@ -202,7 +203,7 @@ std::vector<torch::Tensor>sphere_obb_clpt_wrapper(
torch::Tensor closest_point, // batch size, 3
torch::Tensor sparsity_idx, const torch::Tensor weight,
const torch::Tensor activation_distance,
const torch::Tensor max_distance,
const torch::Tensor obb_accel, // n_boxes, 4, 4
const torch::Tensor obb_bounds, // n_boxes, 3
const torch::Tensor obb_pose, // n_boxes, 4, 4
@@ -210,9 +211,9 @@ std::vector<torch::Tensor>sphere_obb_clpt_wrapper(
const torch::Tensor n_env_obb, // n_boxes, 4, 4
const torch::Tensor env_query_idx, // n_boxes, 4, 4
const int max_nobs, const int batch_size, const int horizon,
const int n_spheres,
const bool transform_back, const bool compute_distance,
const bool use_batch_env, const bool sum_collisions = true,
const bool compute_esdf = false)
{
const at::cuda::OptionalCUDAGuard guard(sphere_position.device());
@@ -305,10 +306,11 @@ std::vector<torch::Tensor>pose_distance_wrapper(
const torch::Tensor weight, const torch::Tensor vec_convergence,
const torch::Tensor run_weight, const torch::Tensor run_vec_weight,
const torch::Tensor offset_waypoint, const torch::Tensor offset_tstep_fraction,
const torch::Tensor batch_pose_idx, const int batch_size, const int horizon,
const torch::Tensor batch_pose_idx,
const torch::Tensor project_distance,
const int batch_size, const int horizon,
const int mode, const int num_goals = 1, const bool compute_grad = false,
const bool write_distance = false, const bool use_metric = false,
const bool project_distance = true)
const bool write_distance = false, const bool use_metric = false)
{
// at::cuda::DeviceGuard guard(angle.device());
CHECK_INPUT(out_distance);
@@ -323,6 +325,7 @@ std::vector<torch::Tensor>pose_distance_wrapper(
CHECK_INPUT(batch_pose_idx);
CHECK_INPUT(offset_waypoint);
CHECK_INPUT(offset_tstep_fraction);
CHECK_INPUT(project_distance);
const at::cuda::OptionalCUDAGuard guard(current_position.device());
return pose_distance(
@@ -332,8 +335,10 @@ std::vector<torch::Tensor>pose_distance_wrapper(
vec_convergence, run_weight, run_vec_weight,
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx, batch_size,
horizon, mode, num_goals, compute_grad, write_distance, use_metric, project_distance);
batch_pose_idx,
project_distance,
batch_size,
horizon, mode, num_goals, compute_grad, write_distance, use_metric);
}
std::vector<torch::Tensor>backward_pose_distance_wrapper(
@@ -372,8 +377,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"Closest Point Voxel(curobolib)");
m.def("swept_closest_point_voxel", &swept_sphere_voxel_clpt,
"Swpet Closest Point Voxel(curobolib)");
m.def("self_collision_distance", &self_collision_distance_wrapper,
"Self Collision Distance (curobolib)");

View File

@@ -146,7 +146,6 @@ namespace Curobo
}
}
template<bool project_distance>
__device__ __forceinline__ void
compute_pose_distance_vector(float *result_vec,
const float3 goal_position,
@@ -156,7 +155,8 @@ namespace Curobo
const float *vec_weight,
const float3 offset_position,
const float3 offset_rotation,
const bool reach_offset)
const bool reach_offset,
const bool project_distance)
{
// project current position to goal frame:
float3 error_position = make_float3(0, 0, 0);
@@ -253,7 +253,7 @@ namespace Curobo
}
}
template<bool project_distance, bool use_metric>
template<bool use_metric>
__device__ __forceinline__ void
compute_pose_distance(float *distance_vec, float& distance, float& position_distance,
float& rotation_distance, const float3 current_position,
@@ -265,17 +265,19 @@ namespace Curobo
const float r_alpha,
const float3 offset_position,
const float3 offset_rotation,
const bool reach_offset)
const bool reach_offset,
const bool project_distance)
{
compute_pose_distance_vector<project_distance>(&distance_vec[0],
goal_position,
goal_quat,
current_position,
current_quat,
&vec_weight[0],
offset_position,
offset_rotation,
reach_offset);
compute_pose_distance_vector(&distance_vec[0],
goal_position,
goal_quat,
current_position,
current_quat,
&vec_weight[0],
offset_position,
offset_rotation,
reach_offset,
project_distance);
position_distance = 0;
rotation_distance = 0;
@@ -394,7 +396,7 @@ namespace Curobo
*(float3 *)&out_grad_q[batch_idx * 4 + 1] = g_q;
}
template<typename scalar_t, bool write_distance, bool project_distance, bool use_metric>
template<typename scalar_t, bool write_distance, bool use_metric>
__global__ void goalset_pose_distance_kernel(
scalar_t *out_distance, scalar_t *out_position_distance,
scalar_t *out_rotation_distance, scalar_t *out_p_vec, scalar_t *out_q_vec,
@@ -405,7 +407,9 @@ namespace Curobo
const scalar_t *run_weight, const scalar_t *run_vec_weight,
const scalar_t *offset_waypoint,
const scalar_t *offset_tstep_fraction,
const int32_t *batch_pose_idx, const int mode, const int num_goals,
const int32_t *batch_pose_idx,
const uint8_t *project_distance_tensor,
const int mode, const int num_goals,
const int batch_size, const int horizon, const bool write_grad = false)
{
const int t_idx = (blockDim.x * blockIdx.x + threadIdx.x);
@@ -416,7 +420,7 @@ namespace Curobo
{
return;
}
const bool project_distance = project_distance_tensor[0];
// read current pose:
float3 position =
*(float3 *)&current_position[batch_idx * horizon * 3 + h_idx * 3];
@@ -511,7 +515,7 @@ namespace Curobo
float4 gq4 = *(float4 *)&goal_quat[(offset + k) * 4];
l_goal_quat = make_float4(gq4.y, gq4.z, gq4.w, gq4.x);
compute_pose_distance<project_distance, use_metric>(&distance_vec[0],
compute_pose_distance<use_metric>(&distance_vec[0],
pose_distance,
position_distance,
rotation_distance,
@@ -531,7 +535,8 @@ namespace Curobo
r_w_alpha,
offset_position,
offset_rotation,
reach_offset);
reach_offset,
project_distance);
if (pose_distance <= best_distance)
{
@@ -657,10 +662,10 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
const torch::Tensor offset_waypoint,
const torch::Tensor offset_tstep_fraction,
const torch::Tensor batch_pose_idx, // batch_size, 1
const torch::Tensor project_distance,
const int batch_size, const int horizon, const int mode,
const int num_goals = 1, const bool compute_grad = false,
const bool write_distance = true, const bool use_metric = false,
const bool project_distance = true)
const bool write_distance = true, const bool use_metric = false)
{
using namespace Curobo::Pose;
@@ -684,8 +689,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (project_distance)
{
if (use_metric)
{
if (write_distance)
@@ -693,7 +697,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
AT_DISPATCH_FLOATING_TYPES(
current_position.scalar_type(), "batch_pose_distance", ([&] {
goalset_pose_distance_kernel
<scalar_t, true, true, true><< < blocksPerGrid, threadsPerBlock, 0,
<scalar_t, true, true><< < blocksPerGrid, threadsPerBlock, 0,
stream >> > (
out_distance.data_ptr<scalar_t>(),
out_position_distance.data_ptr<scalar_t>(),
@@ -711,7 +715,9 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
run_vec_weight.data_ptr<scalar_t>(),
offset_waypoint.data_ptr<scalar_t>(),
offset_tstep_fraction.data_ptr<scalar_t>(),
batch_pose_idx.data_ptr<int32_t>(), mode, num_goals,
batch_pose_idx.data_ptr<int32_t>(),
project_distance.data_ptr<uint8_t>(),
mode, num_goals,
batch_size, horizon, compute_grad);
}));
}
@@ -720,7 +726,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
AT_DISPATCH_FLOATING_TYPES(
current_position.scalar_type(), "batch_pose_distance", ([&] {
goalset_pose_distance_kernel
<scalar_t, false, true, true><< < blocksPerGrid, threadsPerBlock, 0,
<scalar_t, false, true><< < blocksPerGrid, threadsPerBlock, 0,
stream >> > (
out_distance.data_ptr<scalar_t>(),
out_position_distance.data_ptr<scalar_t>(),
@@ -738,7 +744,9 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
run_vec_weight.data_ptr<scalar_t>(),
offset_waypoint.data_ptr<scalar_t>(),
offset_tstep_fraction.data_ptr<scalar_t>(),
batch_pose_idx.data_ptr<int32_t>(), mode, num_goals,
batch_pose_idx.data_ptr<int32_t>(),
project_distance.data_ptr<uint8_t>(),
mode, num_goals,
batch_size, horizon, compute_grad);
}));
}
@@ -749,7 +757,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
{
AT_DISPATCH_FLOATING_TYPES(
current_position.scalar_type(), "batch_pose_distance", ([&] {
goalset_pose_distance_kernel<scalar_t, true, true, false>
goalset_pose_distance_kernel<scalar_t, true, false>
<< < blocksPerGrid, threadsPerBlock, 0, stream >> > (
out_distance.data_ptr<scalar_t>(),
out_position_distance.data_ptr<scalar_t>(),
@@ -767,7 +775,9 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
run_vec_weight.data_ptr<scalar_t>(),
offset_waypoint.data_ptr<scalar_t>(),
offset_tstep_fraction.data_ptr<scalar_t>(),
batch_pose_idx.data_ptr<int32_t>(), mode, num_goals,
batch_pose_idx.data_ptr<int32_t>(),
project_distance.data_ptr<uint8_t>(),
mode, num_goals,
batch_size, horizon, compute_grad);
}));
}
@@ -775,7 +785,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
{
AT_DISPATCH_FLOATING_TYPES(
current_position.scalar_type(), "batch_pose_distance", ([&] {
goalset_pose_distance_kernel<scalar_t, false, true, false>
goalset_pose_distance_kernel<scalar_t, false, false>
<< < blocksPerGrid, threadsPerBlock, 0, stream >> > (
out_distance.data_ptr<scalar_t>(),
out_position_distance.data_ptr<scalar_t>(),
@@ -793,127 +803,15 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance,
run_vec_weight.data_ptr<scalar_t>(),
offset_waypoint.data_ptr<scalar_t>(),
offset_tstep_fraction.data_ptr<scalar_t>(),
batch_pose_idx.data_ptr<int32_t>(), mode, num_goals,
batch_pose_idx.data_ptr<int32_t>(),
project_distance.data_ptr<uint8_t>(),
mode, num_goals,
batch_size, horizon, compute_grad);
}));
}
}
}
else
{
if (use_metric)
{
if (write_distance)
{
AT_DISPATCH_FLOATING_TYPES(
current_position.scalar_type(), "batch_pose_distance", ([&] {
goalset_pose_distance_kernel
<scalar_t, true, false, true><< < blocksPerGrid, threadsPerBlock, 0,
stream >> > (
out_distance.data_ptr<scalar_t>(),
out_position_distance.data_ptr<scalar_t>(),
out_rotation_distance.data_ptr<scalar_t>(),
distance_p_vector.data_ptr<scalar_t>(),
distance_q_vector.data_ptr<scalar_t>(),
out_gidx.data_ptr<int32_t>(),
current_position.data_ptr<scalar_t>(),
goal_position.data_ptr<scalar_t>(),
current_quat.data_ptr<scalar_t>(),
goal_quat.data_ptr<scalar_t>(),
vec_weight.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
vec_convergence.data_ptr<scalar_t>(),
run_weight.data_ptr<scalar_t>(),
run_vec_weight.data_ptr<scalar_t>(),
offset_waypoint.data_ptr<scalar_t>(),
offset_tstep_fraction.data_ptr<scalar_t>(),
batch_pose_idx.data_ptr<int32_t>(), mode, num_goals,
batch_size, horizon, compute_grad);
}));
}
else
{
AT_DISPATCH_FLOATING_TYPES(
current_position.scalar_type(), "batch_pose_distance", ([&] {
goalset_pose_distance_kernel
<scalar_t, false, false, true><< < blocksPerGrid, threadsPerBlock, 0,
stream >> > (
out_distance.data_ptr<scalar_t>(),
out_position_distance.data_ptr<scalar_t>(),
out_rotation_distance.data_ptr<scalar_t>(),
distance_p_vector.data_ptr<scalar_t>(),
distance_q_vector.data_ptr<scalar_t>(),
out_gidx.data_ptr<int32_t>(),
current_position.data_ptr<scalar_t>(),
goal_position.data_ptr<scalar_t>(),
current_quat.data_ptr<scalar_t>(),
goal_quat.data_ptr<scalar_t>(),
vec_weight.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
vec_convergence.data_ptr<scalar_t>(),
run_weight.data_ptr<scalar_t>(),
run_vec_weight.data_ptr<scalar_t>(),
offset_waypoint.data_ptr<scalar_t>(),
offset_tstep_fraction.data_ptr<scalar_t>(),
batch_pose_idx.data_ptr<int32_t>(), mode, num_goals,
batch_size, horizon, compute_grad);
}));
}
}
else
{
if (write_distance)
{
AT_DISPATCH_FLOATING_TYPES(
current_position.scalar_type(), "batch_pose_distance", ([&] {
goalset_pose_distance_kernel<scalar_t, true, false, false>
<< < blocksPerGrid, threadsPerBlock, 0, stream >> > (
out_distance.data_ptr<scalar_t>(),
out_position_distance.data_ptr<scalar_t>(),
out_rotation_distance.data_ptr<scalar_t>(),
distance_p_vector.data_ptr<scalar_t>(),
distance_q_vector.data_ptr<scalar_t>(),
out_gidx.data_ptr<int32_t>(),
current_position.data_ptr<scalar_t>(),
goal_position.data_ptr<scalar_t>(),
current_quat.data_ptr<scalar_t>(),
goal_quat.data_ptr<scalar_t>(),
vec_weight.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
vec_convergence.data_ptr<scalar_t>(),
run_weight.data_ptr<scalar_t>(),
run_vec_weight.data_ptr<scalar_t>(),
offset_waypoint.data_ptr<scalar_t>(),
offset_tstep_fraction.data_ptr<scalar_t>(),
batch_pose_idx.data_ptr<int32_t>(), mode, num_goals,
batch_size, horizon, compute_grad);
}));
}
else
{
AT_DISPATCH_FLOATING_TYPES(
current_position.scalar_type(), "batch_pose_distance", ([&] {
goalset_pose_distance_kernel<scalar_t, false, false, false>
<< < blocksPerGrid, threadsPerBlock, 0, stream >> > (
out_distance.data_ptr<scalar_t>(),
out_position_distance.data_ptr<scalar_t>(),
out_rotation_distance.data_ptr<scalar_t>(),
distance_p_vector.data_ptr<scalar_t>(),
distance_q_vector.data_ptr<scalar_t>(),
out_gidx.data_ptr<int32_t>(),
current_position.data_ptr<scalar_t>(),
goal_position.data_ptr<scalar_t>(),
current_quat.data_ptr<scalar_t>(),
goal_quat.data_ptr<scalar_t>(),
vec_weight.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
vec_convergence.data_ptr<scalar_t>(),
run_weight.data_ptr<scalar_t>(),
run_vec_weight.data_ptr<scalar_t>(),
offset_waypoint.data_ptr<scalar_t>(),
offset_tstep_fraction.data_ptr<scalar_t>(),
batch_pose_idx.data_ptr<int32_t>(), mode, num_goals,
batch_size, horizon, compute_grad);
}));
}
}
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
return { out_distance, out_position_distance, out_rotation_distance,

View File

@@ -41,7 +41,9 @@ namespace Curobo
scalar_t *out_distance, // batch x 1
scalar_t *out_vec, // batch x nspheres x 4
const scalar_t *robot_spheres, // batch x nspheres x 4
const scalar_t *collision_threshold, const int batch_size,
const scalar_t *offsets,
const uint8_t *coll_matrix,
const int batch_size,
const int nspheres, const scalar_t *weight, const bool write_grad = false)
{
const int batch_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -52,37 +54,35 @@ namespace Curobo
}
float r_diff, distance;
float max_penetration = 0;
float3 sph1, sph2, dist_vec;
float4 sph1, sph2;
int sph1_idx = -1;
int sph2_idx = -1;
// iterate over spheres:
for (int i = 0; i < nspheres; i++)
{
sph1 = *(float3 *)&robot_spheres[batch_idx * nspheres * 4 + i * 4];
sph1 = *(float4 *)&robot_spheres[batch_idx * nspheres * 4 + i * 4];
sph1.w += offsets[i];
for (int j = i + 1; j < nspheres; j++)
{
r_diff = collision_threshold[i * nspheres + j];
if (isinf(r_diff))
if(coll_matrix[i * nspheres + j] == 1)
{
continue;
}
sph2 = *(float3 *)&robot_spheres[batch_idx * nspheres * 4 + j * 4];
sph2 = *(float4 *)&robot_spheres[batch_idx * nspheres * 4 + j * 4];
sph2.w += offsets[j];
// compute sphere distance:
distance = relu(r_diff - length(sph1 - sph2));
// compute sphere distance:
r_diff = sph1.w + sph2.w;
float d = sqrt((sph1.x - sph2.x) * (sph1.x - sph2.x) +
(sph1.y - sph2.y) * (sph1.y - sph2.y) +
(sph1.z - sph2.z) * (sph1.z - sph2.z));
distance = (r_diff - d);
if (distance > max_penetration)
{
max_penetration = distance;
sph1_idx = i;
sph2_idx = j;
if (write_grad)
if (distance > max_penetration)
{
dist_vec = normalize(sph1 - sph2);// / distance;
max_penetration = distance;
sph1_idx = i;
sph2_idx = j;
}
}
}
@@ -95,6 +95,11 @@ namespace Curobo
if (write_grad)
{
float3 sph1_g =
*(float3 *)&robot_spheres[4 * (batch_idx * nspheres + sph1_idx)];
float3 sph2_g =
*(float3 *)&robot_spheres[4 * (batch_idx * nspheres + sph2_idx)];
float3 dist_vec = normalize(sph1_g - sph2_g);
*(float3 *)&out_vec[batch_idx * nspheres * 4 + sph1_idx * 4] =
weight[0] * -1 * dist_vec;
*(float3 *)&out_vec[batch_idx * nspheres * 4 + sph2_idx * 4] =
@@ -131,7 +136,7 @@ namespace Curobo
int i = ndpt * (warp_idx / nwpr); // starting row number for this warp
int j = (warp_idx % nwpr) * 32; // starting column number for this warp
dist_t max_d = {0.0, 0.0, 0.0 };// .d, .i, .j
dist_t max_d = { 0.0, 0, 0 };// .d, .i, .j
__shared__ dist_t max_darr[32];
// Optimization: About 1/3 of the warps will have no work.
@@ -354,7 +359,7 @@ namespace Curobo
// in registers (max_d).
// Each thread computes up to ndpt distances.
//////////////////////////////////////////////////////
dist_t max_d[NBPB] = {{ 0.0, 0.0, 0.0}};
dist_t max_d[NBPB] = {{ 0.0, 0, 0}};
int16_t indices[ndpt * 2];
for (uint8_t i = 0; i < ndpt * 2; i++)
@@ -698,7 +703,7 @@ std::vector<torch::Tensor>self_collision_distance(
}
else
{
assert(false); // only ndpt of 32 or 64 is currently supported.
assert(false);
}
}
@@ -713,6 +718,8 @@ std::vector<torch::Tensor>self_collision_distance(
assert(collision_matrix.size(0) == nspheres * nspheres);
int smemSize = nspheres * sizeof(float4);
if (nspheres < 1024 && threadsPerBlock < 1024)
{
AT_DISPATCH_FLOATING_TYPES(
robot_spheres.scalar_type(), "self_collision_distance", ([&] {
@@ -726,6 +733,30 @@ std::vector<torch::Tensor>self_collision_distance(
ndpt_n, nwpr, weight.data_ptr<scalar_t>(),
sparse_index.data_ptr<uint8_t>(), compute_grad);
}));
}
else
{
threadsPerBlock = batch_size;
if (threadsPerBlock > 128)
{
threadsPerBlock = 128;
}
blocksPerGrid = (batch_size + threadsPerBlock - 1) / threadsPerBlock;
AT_DISPATCH_FLOATING_TYPES(
robot_spheres.scalar_type(), "self_collision_distance", ([&] {
self_collision_distance_kernel<scalar_t>
<< < blocksPerGrid, threadsPerBlock, smemSize, stream >> > (
out_distance.data_ptr<scalar_t>(),
out_vec.data_ptr<scalar_t>(),
robot_spheres.data_ptr<scalar_t>(),
collision_offset.data_ptr<scalar_t>(),
collision_matrix.data_ptr<uint8_t>(),
batch_size, nspheres,
weight.data_ptr<scalar_t>(),
compute_grad);
}));
}
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

View File

@@ -52,6 +52,28 @@ namespace Curobo
return max(0.0f, sphere_length(v1, v2) - v1.w - v2.w);
}
__device__ __forceinline__ int3 robust_floor(const float3 f_grid, const float threshold=1e-04)
{
float3 nearest_grid = make_float3(round(f_grid.x), round(f_grid.y), round(f_grid.z));
float3 abs_diff = (f_grid - nearest_grid);
if (abs_diff.x >= threshold)
{
nearest_grid.x = floorf(f_grid.x);
}
if (abs_diff.y >= threshold)
{
nearest_grid.y = floorf(f_grid.y);
}
if (abs_diff.z >= threshold)
{
nearest_grid.z = floorf(f_grid.z);
}
return make_int3(nearest_grid);
}
#if CHECK_FP8
__device__ __forceinline__ float
@@ -487,21 +509,20 @@ namespace Curobo
delta = make_float3(pt.x - sphere.x, pt.y - sphere.y, pt.z - sphere.z);
distance = length(delta);
if (!inside)
if (distance == 0.0)
{
delta = -1.0 * make_float3(pt.x, pt.y, pt.z);
}
if (!inside) // outside
{
distance *= -1.0;
}
else
else // inside
{
delta = -1 * delta;
}
if (distance != 0.0)
{
delta = normalize(delta);
}
delta = normalize(delta);
sph_distance = distance + sphere.w;
//
@@ -685,29 +706,48 @@ float4 &sum_pt)
template<typename grid_scalar_t, bool INTERPOLATION=false>
__device__ __forceinline__ void
compute_voxel_location_params(
compute_voxel_index(
const grid_scalar_t *grid_features,
const float4& loc_grid_params,
const float4& loc_sphere,
int &voxel_idx,
int3 &xyz_loc,
int3 &xyz_grid,
float &interpolated_distance,
int &voxel_idx)
float &interpolated_distance)
{
// convert location to index: can use floor to cast to int.
// to account for negative values, add 0.5 * bounds.
const float3 loc_grid = make_float3(loc_grid_params.x, loc_grid_params.y, loc_grid_params.z);
const float3 loc_grid = make_float3(loc_grid_params.x, loc_grid_params.y, loc_grid_params.z);// - loc_grid_params.w;
const float3 sphere = make_float3(loc_sphere.x, loc_sphere.y, loc_sphere.z);
//xyz_loc = make_int3(floorf((sphere + 0.5 * loc_grid) / loc_grid_params.w));
const float inv_voxel_size = 1.0/loc_grid_params.w;
//xyz_loc = make_int3(sphere * inv_voxel_size) + make_int3(0.5 * loc_grid * inv_voxel_size);
const float inv_voxel_size = 1.0f / loc_grid_params.w;
float3 f_grid = (loc_grid) * inv_voxel_size;
xyz_grid = robust_floor(f_grid) + 1;
xyz_loc = make_int3(((sphere.x + 0.5f * loc_grid.x) * inv_voxel_size),
((sphere.y + 0.5f * loc_grid.y)* inv_voxel_size),
((sphere.z + 0.5f * loc_grid.z) * inv_voxel_size));
// check grid bounds:
// 2 to catch numerical precision errors. 1 can be used when exact.
// We need at least 1 as we
// look at neighbouring voxels for finite difference
const int offset = 2;
if (xyz_loc.x >= xyz_grid.x - offset || xyz_loc.y >= xyz_grid.y - offset || xyz_loc.z >= xyz_grid.z - offset
|| xyz_loc.x <= offset || xyz_loc.y <= offset || xyz_loc.z <= offset
)
{
voxel_idx = -1;
return;
}
xyz_loc = make_int3((sphere + 0.5 * loc_grid) * inv_voxel_size);
//xyz_loc = make_int3(sphere / loc_grid_params.w) + make_int3(floorf(0.5 * loc_grid/loc_grid_params.w));
xyz_grid = make_int3((loc_grid * inv_voxel_size)) + 1;
// find next nearest voxel to current point and then do weighted interpolation:
voxel_idx = xyz_loc.x * xyz_grid.y * xyz_grid.z + xyz_loc.y * xyz_grid.z + xyz_loc.z;
@@ -715,10 +755,6 @@ float4 &sum_pt)
// compute interpolation distance between voxel origin and sphere location:
get_array_value(grid_features, voxel_idx, interpolated_distance);
if(!INTERPOLATION)
{
interpolated_distance += 0.5 * loc_grid_params.w;//max(0.0, (0.3 * loc_grid_params.w) - loc_sphere.w);
}
if(INTERPOLATION)
{
//
@@ -739,41 +775,6 @@ float4 &sum_pt)
}
template<typename grid_scalar_t>
__device__ __forceinline__ void
compute_voxel_index(
const grid_scalar_t *grid_features,
const float4& loc_grid_params,
const float4& loc_sphere,
int &voxel_idx,
int3 &xyz_loc,
int3 &xyz_grid,
float &interpolated_distance)
{
// check if sphere is out of bounds
// loc_grid_params.x contains bounds
float4 local_bounds = 0.5*loc_grid_params - 2*loc_grid_params.w;
if (loc_sphere.x <= -1 * (local_bounds.x) ||
loc_sphere.x >= (local_bounds.x) ||
loc_sphere.y <= -1 * (local_bounds.y) ||
loc_sphere.y >= (local_bounds.y) ||
loc_sphere.z <= -1 * (local_bounds.z) ||
loc_sphere.z >= (local_bounds.z))
{
voxel_idx = -1;
return;
}
compute_voxel_location_params(grid_features, loc_grid_params, loc_sphere, xyz_loc, xyz_grid, interpolated_distance, voxel_idx);
// convert location to index: can use floor to cast to int.
// to account for negative values, add 0.5 * bounds.
}
@@ -979,7 +980,7 @@ float4 &sum_pt)
// Load sphere_position input
float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4];
if (sphere_cache.w <= 0.0)
if (sphere_cache.w < 0.0)
{
// write zeros for cost:
out_distance[bn_sph_idx] = 0;
@@ -1044,7 +1045,7 @@ float4 &sum_pt)
// Load sphere_position input
float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4];
if (sphere_cache.w <= 0.0)
if (sphere_cache.w < 0.0)
{
// write zeros for cost:
out_distance[bn_sph_idx] = 0;
@@ -1173,7 +1174,7 @@ float4 &sum_pt)
// Load sphere_position input
float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4];
if (sphere_cache.w <= 0.0)
if (sphere_cache.w < 0.0)
{
// write zeros for cost:
out_distance[bn_sph_idx] = 0;
@@ -1275,7 +1276,7 @@ float4 &sum_pt)
// Load sphere_position input
float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4];
if (sphere_cache.w <= 0.0)
if (sphere_cache.w < 0.0)
{
// write zeros for cost:
out_distance[bn_sph_idx] = 0;
@@ -1452,7 +1453,7 @@ float4 &sum_pt)
// Load sphere_position input
float4 sphere_1_cache = *(float4 *)&sphere_position[bhs_idx * 4];
if (sphere_1_cache.w <= 0.0)
if (sphere_1_cache.w < 0.0)
{
// write zeros for cost:
out_distance[bhs_idx] = 0;
@@ -1888,7 +1889,7 @@ float4 &sum_pt)
// Load sphere_position input
float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4];
if (sphere_cache.w <= 0.0)
if (sphere_cache.w < 0.0)
{
// write zeros for cost:
out_distance[bn_sph_idx] = 0;
@@ -2001,7 +2002,7 @@ float4 &sum_pt)
float4 sphere_1_cache = *(float4 *)&sphere_position[bhs_idx * 4];
if (sphere_1_cache.w <= 0.0)
if (sphere_1_cache.w < 0.0)
{
// write zeros for cost:
out_distance[bhs_idx] = 0;
@@ -2303,7 +2304,7 @@ float4 &sum_pt)
// if h_idx == horizon -1, we just read the same index
float4 sphere_1_cache = *(float4 *)&sphere_position[bhs_idx * 4];
if (sphere_1_cache.w <= 0.0)
if (sphere_1_cache.w < 0.0)
{
out_distance[b_addrs + h_idx * nspheres + sph_idx] = 0.0;
return;

View File

@@ -157,6 +157,7 @@ def get_pose_distance(
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
project_distance,
batch_size,
horizon,
mode=1,
@@ -164,7 +165,6 @@ def get_pose_distance(
write_grad=False,
write_distance=False,
use_metric=False,
project_distance=True,
):
if batch_pose_idx.shape[0] != batch_size:
raise ValueError("Index buffer size is different from batch size")
@@ -188,6 +188,7 @@ def get_pose_distance(
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
project_distance,
batch_size,
horizon,
mode,
@@ -195,7 +196,6 @@ def get_pose_distance(
write_grad,
write_distance,
use_metric,
project_distance,
)
out_distance = r[0]
@@ -272,6 +272,7 @@ class PoseErrorDistance(torch.autograd.Function):
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
project_distance,
out_distance,
out_position_distance,
out_rotation_distance,
@@ -284,8 +285,7 @@ class PoseErrorDistance(torch.autograd.Function):
horizon,
mode, # =PoseErrorType.BATCH_GOAL.value,
num_goals,
use_metric, # =False,
project_distance, # =True,
use_metric,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
@@ -322,6 +322,7 @@ class PoseErrorDistance(torch.autograd.Function):
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
project_distance,
batch_size,
horizon,
mode,
@@ -329,7 +330,6 @@ class PoseErrorDistance(torch.autograd.Function):
current_position.requires_grad,
True,
use_metric,
project_distance,
)
ctx.save_for_backward(out_p_vec, out_r_vec, weight, out_p_grad, out_q_grad)
return out_distance, out_position_distance, out_rotation_distance, out_idx # .view(-1,1)
@@ -406,6 +406,7 @@ class PoseError(torch.autograd.Function):
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
project_distance,
out_distance,
out_position_distance,
out_rotation_distance,
@@ -419,7 +420,6 @@ class PoseError(torch.autograd.Function):
mode,
num_goals,
use_metric,
project_distance,
return_loss,
):
"""Compute error in pose
@@ -494,6 +494,7 @@ class PoseError(torch.autograd.Function):
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
project_distance,
batch_size,
horizon,
mode,
@@ -501,7 +502,6 @@ class PoseError(torch.autograd.Function):
current_position.requires_grad,
False,
use_metric,
project_distance,
)
ctx.save_for_backward(out_p_vec, out_r_vec)
return out_distance

View File

@@ -113,7 +113,6 @@ class KinematicsFusedFunction(Function):
@staticmethod
def backward(ctx, grad_out_link_pos, grad_out_link_quat, grad_out_spheres):
grad_joint = None
if ctx.needs_input_grad[4]:
(
joint_seq,
@@ -193,10 +192,14 @@ class KinematicsFusedFunction(Function):
b_size = b_shape[0]
n_spheres = robot_sphere_out.shape[1]
n_joints = angle.shape[-1]
if grad_out.is_contiguous():
grad_out = grad_out.view(-1)
else:
grad_out = grad_out.reshape(-1)
grad_out = grad_out.contiguous()
link_pos_out = link_pos_out.contiguous()
link_quat_out = link_quat_out.contiguous()
# if grad_out.is_contiguous():
# grad_out = grad_out.view(-1)
# else:
# grad_out = grad_out.reshape(-1)
r = kinematics_fused_cu.backward(
grad_out,
link_pos_out,

View File

@@ -58,7 +58,6 @@ def wolfe_line_search(
l1 = g_x.shape[1]
l2 = g_x.shape[2]
r = line_search_cu.line_search(
# m_idx,
best_x,
best_c,
best_grad,
@@ -76,7 +75,6 @@ def wolfe_line_search(
l2,
batchsize,
)
# print("batchsize:" + str(batchsize))
return (r[0], r[1], r[2])

View File

@@ -633,6 +633,7 @@ class WorldCollision(WorldCollisionConfig):
self,
cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]),
voxel_size: float = 0.02,
run_marching_cubes: bool = True,
) -> Mesh:
"""Get a mesh representation of the world obstacles based on occupancy in a bounding box.
@@ -642,19 +643,31 @@ class WorldCollision(WorldCollisionConfig):
Args:
cuboid: Bounding box to get the mesh representation.
voxel_size: Size of the voxels in meters.
run_marching_cubes: Runs marching cubes over occupied voxels to generate a mesh. If
set to False, then all occupied voxels are merged into a mesh and returned.
Returns:
Mesh representation of the world obstacles in the bounding box.
"""
voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size)
# voxels = voxels.cpu().numpy()
# cuboids = [Cuboid(name="c_"+str(x), pose=[voxels[x,0],voxels[x,1],voxels[x,2], 1,0,0,0],
# dims=[voxel_size, voxel_size, voxel_size]) for x in range(voxels.shape[0])]
# mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0]
mesh = Mesh.from_pointcloud(
voxels[:, :3].detach().cpu().numpy(),
pitch=voxel_size * 1.1,
)
voxels = voxels.cpu().numpy()
if run_marching_cubes:
mesh = Mesh.from_pointcloud(
voxels[:, :3].detach().cpu().numpy(),
pitch=voxel_size * 1.1,
)
else:
cuboids = [
Cuboid(
name="c_" + str(x),
pose=[voxels[x, 0], voxels[x, 1], voxels[x, 2], 1, 0, 0, 0],
dims=[voxel_size, voxel_size, voxel_size],
)
for x in range(voxels.shape[0])
]
mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0]
return mesh
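A minimal usage sketch for the new flag, assuming a configured WorldCollision instance named world_ccheck; the method name is not visible in this hunk, so get_mesh_in_bounding_box below is an assumption:

# Sketch: world_ccheck and the method name are assumptions, not shown in the hunk.
from curobo.geom.types import Cuboid

bounding_box = Cuboid(name="bounds", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1])

# Default: marching cubes over occupied voxels produces a smooth mesh.
mesh = world_ccheck.get_mesh_in_bounding_box(bounding_box, voxel_size=0.02)

# New fallback: merge every occupied voxel into a blocky cuboid mesh instead.
blocky = world_ccheck.get_mesh_in_bounding_box(
    bounding_box, voxel_size=0.02, run_marching_cubes=False
)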
def get_obstacle_names(self, env_idx: int = 0) -> List[str]:

View File

@@ -358,12 +358,21 @@ class WorldVoxelCollision(WorldMeshCollision):
env_idx: Environment index to update voxel grid in.
"""
obs_idx = self.get_voxel_idx(new_voxel.name, env_idx)
self._voxel_tensor_list[3][env_idx, obs_idx, :, :] = new_voxel.feature_tensor.view(
new_voxel.feature_tensor.shape[0], -1
).to(dtype=self._voxel_tensor_list[3].dtype)
self._voxel_tensor_list[0][env_idx, obs_idx, :3] = self.tensor_args.to_device(
new_voxel.dims
)
feature_tensor = new_voxel.feature_tensor.view(new_voxel.feature_tensor.shape[0], -1)
if (
feature_tensor.shape[0] != self._voxel_tensor_list[3][env_idx, obs_idx, :, :].shape[0]
or feature_tensor.shape[1]
!= self._voxel_tensor_list[3][env_idx, obs_idx, :, :].shape[1]
):
log_error(
"Feature tensor shape mismatch, existing shape: "
+ str(self._voxel_tensor_list[3][env_idx, obs_idx, :, :].shape)
+ " New shape: "
+ str(feature_tensor.shape)
)
self._voxel_tensor_list[3][env_idx, obs_idx, :, :].copy_(feature_tensor)
self._voxel_tensor_list[0][env_idx, obs_idx, :3].copy_(torch.as_tensor(new_voxel.dims))
self._voxel_tensor_list[0][env_idx, obs_idx, 3] = new_voxel.voxel_size
self._voxel_tensor_list[1][env_idx, obs_idx, :7] = (
Pose.from_list(new_voxel.pose, self.tensor_args).inverse().get_pose_vector()
@@ -876,14 +885,19 @@ class WorldVoxelCollision(WorldMeshCollision):
self._env_n_voxels[:] = 0
super().clear_cache()
def get_voxel_grid_shape(self, env_idx: int = 0, obs_idx: int = 0) -> torch.Size:
def get_voxel_grid_shape(
self, env_idx: int = 0, obs_idx: int = 0, name: Optional[str] = None
) -> torch.Size:
"""Get dimensions of the voxel grid.
Args:
env_idx: Environment index.
obs_idx: Obstacle index.
name: Name of obstacle. When provided, obs_idx is ignored.
Returns:
Shape of the voxel grid.
"""
if name is not None:
obs_idx = self.get_voxel_idx(name, env_idx)
return self._voxel_tensor_list[3][env_idx, obs_idx].shape
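A short sketch of the name-based lookup added above, with illustrative names:

# voxel_world is an assumed WorldVoxelCollision instance; "layer_1" is illustrative.
shape = voxel_world.get_voxel_grid_shape(env_idx=0, name="layer_1")

# Equivalent to resolving the obstacle index manually:
obs_idx = voxel_world.get_voxel_idx("layer_1", 0)
assert shape == voxel_world.get_voxel_grid_shape(env_idx=0, obs_idx=obs_idx)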

View File

@@ -13,7 +13,6 @@
from __future__ import annotations
# Standard Library
import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -28,6 +27,7 @@ from curobo.geom.sphere_fit import SphereFitType, fit_spheres_to_mesh
from curobo.types.base import TensorDeviceType
from curobo.types.camera import CameraObservation
from curobo.types.math import Pose
from curobo.util.helpers import robust_floor
from curobo.util.logger import log_error, log_warn
from curobo.util_file import get_assets_path, join_path
@@ -723,12 +723,16 @@ class VoxelGrid(Obstacle):
"""Get shape of voxel grid."""
bounds = self.dims
grid_shape = [bounds[0], bounds[1], bounds[2]]
inv_voxel_size = 1.0 / self.voxel_size
grid_shape = [1 + robust_floor(x * inv_voxel_size) for x in grid_shape]
low = [-bounds[0] / 2, -bounds[1] / 2, -bounds[2] / 2]
high = [bounds[0] / 2, bounds[1] / 2, bounds[2] / 2]
grid_shape = [
1 + int(high[i] / self.voxel_size) - (int(low[i] / self.voxel_size))
for i in range(len(low))
]
return grid_shape, low, high
def create_xyzr_tensor(
@@ -745,10 +749,19 @@ class VoxelGrid(Obstacle):
"""
trange, low, high = self.get_grid_shape()
x = torch.linspace(low[0], high[0], trange[0], device=tensor_args.device)
y = torch.linspace(low[1], high[1], trange[1], device=tensor_args.device)
z = torch.linspace(low[2], high[2], trange[2], device=tensor_args.device)
inv_voxel_size = 1.0 / self.voxel_size
x = torch.linspace(1, trange[0], trange[0], device=tensor_args.device) - round(
(0.5 * self.dims[0]) * inv_voxel_size
)
y = torch.linspace(1, trange[1], trange[1], device=tensor_args.device) - round(
(0.5 * self.dims[1]) * inv_voxel_size
)
z = torch.linspace(1, trange[2], trange[2], device=tensor_args.device) - round(
(0.5 * self.dims[2]) * inv_voxel_size
)
x = x * self.voxel_size - 0.5 * self.voxel_size
y = y * self.voxel_size - 0.5 * self.voxel_size
z = z * self.voxel_size - 0.5 * self.voxel_size
w, l, h = x.shape[0], y.shape[0], z.shape[0]
xyz = (
torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3)
@@ -757,7 +770,7 @@ class VoxelGrid(Obstacle):
if transform_to_origin:
pose = Pose.from_list(self.pose, tensor_args=tensor_args)
xyz = pose.transform_points(xyz.contiguous())
r = torch.zeros_like(xyz[:, 0:1]) + (self.voxel_size * 0.5)
r = torch.zeros_like(xyz[:, 0:1])
xyzr = torch.cat([xyz, r], dim=1)
return xyzr
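A standalone sketch of the new voxel-center arithmetic for a single axis (pure torch plus the robust_floor helper added in this commit; values are illustrative):

import torch
from curobo.util.helpers import robust_floor  # added in this commit

voxel_size = 0.02
dim = 0.12  # grid extent along one axis, meters

inv_voxel_size = 1.0 / voxel_size
n = 1 + robust_floor(dim * inv_voxel_size)  # voxel count, robust to fp error
idx = torch.linspace(1, n, n) - round(0.5 * dim * inv_voxel_size)
centers = idx * voxel_size - 0.5 * voxel_size
# Consecutive centers differ by exactly voxel_size; building them from integer
# indices avoids the accumulation error of linspace(low, high, n).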

View File

@@ -10,19 +10,10 @@
#
# Standard Library
import random
# Third Party
import networkx as nx
import numpy as np
import torch
# This is needed to get deterministic results from networkx.
# Note: it has to be set in global space.
np.random.seed(2)
random.seed(2)
class NetworkxGraph(object):
def __init__(self):
@@ -63,7 +54,11 @@ class NetworkxGraph(object):
def path_exists(self, start_node_idx, goal_node_idx):
self.update_graph()
return nx.has_path(self.graph, start_node_idx, goal_node_idx)
# check if nodes exist in the graph
if self.graph.has_node(start_node_idx) and self.graph.has_node(goal_node_idx):
return nx.has_path(self.graph, start_node_idx, goal_node_idx)
else:
return False
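The node-existence guard matters because networkx raises instead of returning False for unknown nodes:

import networkx as nx

g = nx.Graph()
g.add_edge(0, 1)
nx.has_path(g, 0, 1)  # True
# nx.has_path(g, 0, 99) raises nx.NodeNotFound; the guard above returns
# False instead of propagating that exception.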
def get_shortest_path(self, start_node_idx, goal_node_idx, return_length=False):
self.update_graph()

View File

@@ -603,13 +603,30 @@ def _wolfe_search_tail_jit(c, g_x, x_set, m, d_opt: int):
@get_torch_jit_decorator()
def scale_action(dx, action_step_max):
def scale_action_old(dx, action_step_max):
scale_value = torch.max(torch.abs(dx) / action_step_max, dim=-1, keepdim=True)[0]
scale_value = torch.clamp(scale_value, 1.0)
dx_scaled = dx / scale_value
return dx_scaled
@get_torch_jit_decorator()
def scale_action(dx, action_step_max):
# get largest dx scaled by bounds across optimization variables
scale_value = torch.max(torch.abs(dx) / action_step_max, dim=-1, keepdim=True)[0]
# scale dx to bring all dx within bounds:
# only perform for dx that are greater than 1:
new_scale = torch.where(scale_value <= 1.0, 1.0, scale_value)
dx_scaled = dx / new_scale
# scale_value = torch.clamp(scale_value, 1.0)
# dx_scaled = dx / scale_value
return dx_scaled
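A small numeric check of the rewritten scaling, assuming a step bound of 0.5 per variable:

import torch

dx = torch.tensor([[0.2, -1.0]])
action_step_max = torch.tensor([0.5, 0.5])

scale = torch.max(torch.abs(dx) / action_step_max, dim=-1, keepdim=True)[0]  # 2.0
scale = torch.where(scale <= 1.0, torch.ones_like(scale), scale)
print(dx / scale)  # tensor([[ 0.1000, -0.5000]]): step shrunk uniformly
# A dx already within bounds gives scale <= 1 and passes through unchanged.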
@get_torch_jit_decorator()
def check_convergence(
best_iteration: torch.Tensor, current_iteration: torch.Tensor, last_best: int

View File

@@ -435,7 +435,14 @@ class ArmReacher(ArmBase, ArmReacherConfig):
self.dist_cost.disable_cost()
self.cspace_convergence.disable_cost()
def get_pose_costs(self, include_link_pose: bool = False, include_convergence: bool = True):
def get_pose_costs(
self,
include_link_pose: bool = False,
include_convergence: bool = True,
only_convergence: bool = False,
):
if only_convergence:
return [self.pose_convergence]
pose_costs = [self.goal_cost]
if include_convergence:
pose_costs += [self.pose_convergence]
@@ -447,33 +454,15 @@ class ArmReacher(ArmBase, ArmReacherConfig):
self,
metric: PoseCostMetric,
):
pose_costs = self.get_pose_costs(include_link_pose=metric.include_link_pose)
if metric.hold_partial_pose:
if metric.hold_vec_weight is None:
log_error("hold_vec_weight is required")
[x.hold_partial_pose(metric.hold_vec_weight) for x in pose_costs]
if metric.release_partial_pose:
[x.release_partial_pose() for x in pose_costs]
if metric.reach_partial_pose:
if metric.reach_vec_weight is None:
log_error("reach_vec_weight is required")
[x.reach_partial_pose(metric.reach_vec_weight) for x in pose_costs]
if metric.reach_full_pose:
[x.reach_full_pose() for x in pose_costs]
pose_costs = self.get_pose_costs(
include_link_pose=metric.include_link_pose, include_convergence=False
)
for p in pose_costs:
p.update_metric(metric, update_offset_waypoint=True)
pose_costs = self.get_pose_costs(include_convergence=False)
if metric.remove_offset_waypoint:
[x.remove_offset_waypoint() for x in pose_costs]
if metric.offset_position is not None or metric.offset_rotation is not None:
[
x.update_offset_waypoint(
offset_position=metric.offset_position,
offset_rotation=metric.offset_rotation,
offset_tstep_fraction=metric.offset_tstep_fraction,
)
for x in pose_costs
]
pose_costs = self.get_pose_costs(only_convergence=True)
for p in pose_costs:
p.update_metric(metric, update_offset_waypoint=False)
@get_torch_jit_decorator()

View File

@@ -86,6 +86,7 @@ class PoseCostMetric:
offset_tstep_fraction: float = -1.0
remove_offset_waypoint: bool = False
include_link_pose: bool = False
project_to_goal_frame: Optional[bool] = None
def clone(self):
@@ -102,6 +103,8 @@ class PoseCostMetric:
offset_rotation=None if self.offset_rotation is None else self.offset_rotation.clone(),
offset_tstep_fraction=self.offset_tstep_fraction,
remove_offset_waypoint=self.remove_offset_waypoint,
include_link_pose=self.include_link_pose,
project_to_goal_frame=self.project_to_goal_frame,
)
@classmethod
@@ -110,6 +113,7 @@ class PoseCostMetric:
offset_position: float = 0.1,
linear_axis: int = 2,
tstep_fraction: float = 0.8,
project_to_goal_frame: Optional[bool] = None,
tensor_args: TensorDeviceType = TensorDeviceType(),
) -> PoseCostMetric:
"""Enables moving to a pregrasp and then locked orientation movement to final grasp.
@@ -121,6 +125,8 @@ class PoseCostMetric:
offset_position: offset in meters.
linear_axis: specifies the x y or z axis.
tstep_fraction: specifies the timestep fraction to start activating this constraint.
project_to_goal_frame: compute distance w.r.t. the goal frame instead of the robot
base frame. If None, the value set in PoseCostConfig is used.
tensor_args: cuda device.
Returns:
@@ -150,12 +156,17 @@ class PoseCost(CostBase, PoseCostConfig):
def __init__(self, config: PoseCostConfig):
PoseCostConfig.__init__(self, **vars(config))
CostBase.__init__(self)
self.project_distance_tensor = torch.tensor(
[self.project_distance],
device=self.tensor_args.device,
dtype=torch.uint8,
)
self.rot_weight = self.vec_weight[0:3]
self.pos_weight = self.vec_weight[3:6]
self._vec_convergence = self.tensor_args.to_device(self.vec_convergence)
self._batch_size = 0
def update_metric(self, metric: PoseCostMetric):
def update_metric(self, metric: PoseCostMetric, update_offset_waypoint: bool = True):
if metric.hold_partial_pose:
if metric.hold_vec_weight is None:
log_error("hold_vec_weight is required")
@@ -168,19 +179,22 @@ class PoseCost(CostBase, PoseCostConfig):
self.reach_partial_pose(metric.reach_vec_weight)
if metric.reach_full_pose:
self.reach_full_pose()
if metric.project_to_goal_frame is not None:
self.project_distance_tensor[:] = metric.project_to_goal_frame
else:
self.project_distance_tensor[:] = self.project_distance
if update_offset_waypoint:
if metric.remove_offset_waypoint:
self.remove_offset_waypoint()
if metric.remove_offset_waypoint:
self.remove_offset_waypoint()
if metric.offset_position is not None or metric.offset_rotation is not None:
self.update_offset_waypoint(
offset_position=self.offset_position,
offset_rotation=self.offset_rotation,
offset_tstep_fraction=self.offset_tstep_fraction,
)
if metric.offset_position is not None or metric.offset_rotation is not None:
self.update_offset_waypoint(
offset_position=metric.offset_position,
offset_rotation=metric.offset_rotation,
offset_tstep_fraction=metric.offset_tstep_fraction,
)
def hold_partial_pose(self, run_vec_weight: torch.Tensor):
self.run_vec_weight.copy_(run_vec_weight)
def release_partial_pose(self):
@@ -391,6 +405,7 @@ class PoseCost(CostBase, PoseCostConfig):
self.offset_waypoint,
self.offset_tstep_fraction,
goal.batch_pose_idx,
self.project_distance_tensor,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
@@ -404,7 +419,6 @@ class PoseCost(CostBase, PoseCostConfig):
self.cost_type.value,
num_goals,
self.use_metric,
self.project_distance,
)
# print(self.out_idx.shape, self.out_idx[:,-1])
# print(goal.batch_pose_idx.shape)
@@ -444,6 +458,7 @@ class PoseCost(CostBase, PoseCostConfig):
self.offset_waypoint,
self.offset_tstep_fraction,
goal.batch_pose_idx,
self.project_distance_tensor,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
@@ -457,7 +472,6 @@ class PoseCost(CostBase, PoseCostConfig):
self.cost_type.value,
num_goals,
self.use_metric,
self.project_distance,
self.return_loss,
)
@@ -498,6 +512,7 @@ class PoseCost(CostBase, PoseCostConfig):
self.offset_waypoint,
self.offset_tstep_fraction,
batch_pose_idx,
self.project_distance_tensor,
self.out_distance,
self.out_position_distance,
self.out_rotation_distance,
@@ -511,7 +526,6 @@ class PoseCost(CostBase, PoseCostConfig):
self.cost_type.value,
num_goals,
self.use_metric,
self.project_distance,
self.return_loss,
)
return distance
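A hedged sketch of selecting the distance frame per metric; the factory name below (create_grasp_approach_metric) is an assumption, as the identifier is truncated in the hunk above:

from curobo.rollout.cost.pose_cost import PoseCostMetric

metric = PoseCostMetric.create_grasp_approach_metric(  # assumed name
    offset_position=0.1,         # pregrasp offset, meters
    linear_axis=2,               # constrain approach motion to the z axis
    tstep_fraction=0.8,          # activate constraint for the final timesteps
    project_to_goal_frame=True,  # None keeps PoseCostConfig.project_distance
)
# update_metric() now writes this choice into project_distance_tensor, so the
# kernel reads it at runtime instead of requiring a separate template variant.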

View File

@@ -449,7 +449,14 @@ class KinematicModel(KinematicModelConfig):
state_seq = self.state_seq
curr_batch_size = self.batch_size
num_traj_points = self.horizon
if not state_seq.position.is_contiguous():
state_seq.position = state_seq.position.contiguous()
if not state_seq.velocity.is_contiguous():
state_seq.velocity = state_seq.velocity.contiguous()
if not state_seq.acceleration.is_contiguous():
state_seq.acceleration = state_seq.acceleration.contiguous()
if not state_seq.jerk.is_contiguous():
state_seq.jerk = state_seq.jerk.contiguous()
with profiler.record_function("tensor_step"):
# forward step with step matrix:
state_seq = self.tensor_step(start_state_shaped, act_seq, state_seq, start_state_idx)

View File

@@ -121,7 +121,7 @@ class CameraObservation:
point_cloud = project_depth_using_rays(depth_image, self.projection_rays)
if project_to_pose and self.pose is not None:
point_cloud = self.pose.batch_transform(point_cloud)
point_cloud = self.pose.batch_transform_points(point_cloud)
return point_cloud

View File

@@ -507,37 +507,25 @@ def angular_distance_phi3(goal_quat, current_quat):
class OrientationError(Function):
@staticmethod
def geodesic_distance(goal_quat, current_quat, quat_res):
conjugate_quat = current_quat.clone()
conjugate_quat[..., 1:] *= -1.0
quat_res = quat_multiply(goal_quat, conjugate_quat, quat_res)
quat_res = -1.0 * quat_res * torch.sign(quat_res[..., 0]).unsqueeze(-1)
quat_res[..., 0] = 0.0
# quat_res = conjugate_quat * 0.0
return quat_res
quat_grad, rot_error = geodesic_distance(goal_quat, current_quat, quat_res)
return quat_grad, rot_error
@staticmethod
def forward(ctx, goal_quat, current_quat, quat_res):
quat_res = OrientationError.geodesic_distance(goal_quat, current_quat, quat_res)
rot_error = torch.norm(quat_res, dim=-1, keepdim=True)
ctx.save_for_backward(quat_res, rot_error)
quat_grad, rot_error = OrientationError.geodesic_distance(goal_quat, current_quat, quat_res)
ctx.save_for_backward(quat_grad)
return rot_error
@staticmethod
def backward(ctx, grad_out):
grad_mul = None
if ctx.needs_input_grad[1]:
(quat_error, r_err) = ctx.saved_tensors
scale = 1 / r_err
scale = torch.nan_to_num(scale, 0, 0, 0)
grad_mul = grad_mul1 = None
(quat_grad,) = ctx.saved_tensors
grad_mul = grad_out * scale * quat_error
# print(grad_out.shape)
# if grad_out.shape[0] == 6:
# #print(grad_out.view(-1))
# #print(grad_mul.view(-1)[-6:])
# #exit()
return None, grad_mul, None
if ctx.needs_input_grad[1]:
grad_mul = grad_out * quat_grad
if ctx.needs_input_grad[0]:
grad_mul1 = -1.0 * grad_out * quat_grad
return grad_mul1, grad_mul, None
@get_torch_jit_decorator()
@@ -549,3 +537,19 @@ def normalize_quaternion(in_quaternion: torch.Tensor) -> torch.Tensor:
# normalize quaternion
in_q = k2 * in_quaternion
return in_q
@get_torch_jit_decorator()
def geodesic_distance(goal_quat, current_quat, quat_res):
conjugate_quat = current_quat.detach().clone()
conjugate_quat[..., 1:] *= -1.0
quat_res = quat_multiply(goal_quat, conjugate_quat, quat_res)
sign = torch.sign(quat_res[..., 0])
sign = torch.where(sign == 0, 1.0, sign)
quat_res = -1.0 * quat_res * sign.unsqueeze(-1)
quat_res[..., 0] = 0.0
rot_error = torch.norm(quat_res, dim=-1, keepdim=True)
scale = 1.0 / rot_error
scale = torch.nan_to_num(scale, 0.0, 0.0, 0.0)
quat_res = quat_res * scale
return quat_res, rot_error
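A self-contained recap of the quaternion error math above in plain torch (quat_multiply is sketched inline rather than imported):

import torch

def quat_mul_conj(q1, q2):
    # Hamilton product q1 * conjugate(q2), wxyz layout; stand-in for
    # quat_multiply applied to a conjugated quaternion.
    w1, x1, y1, z1 = q1.unbind(-1)
    w2, x2, y2, z2 = (q2 * torch.tensor([1.0, -1.0, -1.0, -1.0])).unbind(-1)
    return torch.stack((
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ), dim=-1)

goal = torch.tensor([1.0, 0.0, 0.0, 0.0])              # identity
curr = torch.tensor([0.7071068, 0.7071068, 0.0, 0.0])  # 90 deg about x
q_err = quat_mul_conj(goal, curr)
sign = torch.sign(q_err[..., 0:1])
sign = torch.where(sign == 0, torch.ones_like(sign), sign)
q_err = -1.0 * q_err * sign
q_err[..., 0] = 0.0
print(torch.norm(q_err, dim=-1))  # ~0.7071 = sin(90 deg / 2)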

View File

@@ -9,6 +9,7 @@
# its affiliates is strictly prohibited.
#
# Standard Library
import math
from collections import defaultdict
from typing import List
@@ -27,3 +28,11 @@ def list_idx_if_not_none(d_list: List, idx: int):
else:
idx_list.append(None)
return idx_list
def robust_floor(x: float, threshold: float = 1e-04) -> int:
nearest_int = round(x)
if abs(x - nearest_int) < threshold:
return nearest_int
else:
return int(math.floor(x))
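Why the threshold matters: floats that should be exact multiples often land just below the integer.

import math

x = 0.3 / 0.1        # 2.9999999999999996 due to fp rounding
int(math.floor(x))   # 2 (off-by-one when sizing a grid)
robust_floor(x)      # 3 (snaps to the nearest int within 1e-4)
robust_floor(2.7)    # 2 (far from an integer: plain floor)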

View File

@@ -137,7 +137,7 @@ class HaltonSampleLib(BaseSampleLib):
return self.samples
def bspline(c_arr, t_arr=None, n=100, degree=3):
def bspline(c_arr: torch.Tensor, t_arr=None, n=100, degree=3):
sample_device = c_arr.device
sample_dtype = c_arr.dtype
cv = c_arr.cpu().numpy()

View File

@@ -178,3 +178,9 @@ def get_cache_fn_decorator(maxsize: Optional[int] = None):
def empty_decorator(function):
return function
@get_torch_jit_decorator()
def round_away_from_zero(x: torch.Tensor) -> torch.Tensor:
y = torch.trunc(x + 0.5 * torch.sign(x))
return y
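Unlike torch.round, which rounds halves to the nearest even value, this rounds halves away from zero:

import torch

x = torch.tensor([-1.5, -0.5, 0.5, 1.5, 2.5])
torch.round(x)           # tensor([-2., -0., 0., 2., 2.])  half-to-even
round_away_from_zero(x)  # tensor([-2., -1., 1., 2., 3.])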

View File

@@ -99,7 +99,7 @@ def get_linear_traj(
return trajectory
def get_smooth_trajectory(raw_traj, degree=5):
def get_smooth_trajectory(raw_traj: torch.Tensor, degree: int = 5):
cpu_traj = raw_traj.cpu()
smooth_traj = torch.zeros_like(cpu_traj)
@@ -108,11 +108,10 @@ def get_smooth_trajectory(raw_traj, degree=5):
return smooth_traj.to(dtype=raw_traj.dtype, device=raw_traj.device)
def get_spline_interpolated_trajectory(raw_traj, des_horizon, degree=5):
def get_spline_interpolated_trajectory(raw_traj: torch.Tensor, des_horizon: int, degree: int = 5):
retimed_traj = torch.zeros((des_horizon, raw_traj.shape[-1]))
tensor_args = TensorDeviceType(device=raw_traj.device, dtype=raw_traj.dtype)
cpu_traj = raw_traj.cpu().numpy()
cpu_traj = raw_traj.cpu()
for i in range(cpu_traj.shape[-1]):
retimed_traj[:, i] = bspline(cpu_traj[:, i], n=des_horizon, degree=degree)
retimed_traj = retimed_traj.to(**(tensor_args.as_torch_dict()))
@@ -179,7 +178,7 @@ def get_batch_interpolated_trajectory(
opt_dt[:] = raw_dt
# traj_steps contains the tsteps for each trajectory
if steps_max <= 0:
log_error("Steps max is less than 0")
log_error("Steps max is less than 1, with a value: " + str(steps_max))
if out_traj_state is not None and out_traj_state.position.shape[1] < steps_max:
log_warn(
@@ -610,5 +609,7 @@ def calculate_tsteps(
)
if not optimize_dt:
opt_dt[:] = raw_dt
# check for nan:
opt_dt = torch.nan_to_num(opt_dt, nan=min_dt)
traj_steps, steps_max = calculate_traj_steps(opt_dt, interpolation_dt, horizon)
return traj_steps, steps_max, opt_dt

View File

@@ -8,6 +8,7 @@
# its affiliates is strictly prohibited.
# Standard Library
from copy import deepcopy
from typing import Any, Dict, Optional
# CuRobo
@@ -17,9 +18,13 @@ from curobo.util.logger import log_error, log_warn
from curobo.util_file import load_yaml
def return_value_if_exists(input_dict: Dict, key: str, suffix: str = "xrdf") -> Any:
def return_value_if_exists(
input_dict: Dict, key: str, suffix: str = "xrdf", raise_error: bool = True
) -> Any:
if key not in input_dict:
log_error(key + " key not found in " + suffix)
if raise_error:
log_error(key + " key not found in " + suffix)
return None
return input_dict[key]
@@ -42,7 +47,6 @@ def convert_xrdf_to_curobo(
if return_value_if_exists(input_xrdf_dict, "format") != "xrdf":
log_error("format is not xrdf")
raise ValueError("format is not xrdf")
if return_value_if_exists(input_xrdf_dict, "format_version") > 1.0:
log_warn("format_version is greater than 1.0")
@@ -63,7 +67,11 @@ def convert_xrdf_to_curobo(
coll_spheres = return_value_if_exists(input_xrdf_dict["geometry"][coll_name], "spheres")
output_dict["collision_spheres"] = coll_spheres
buffer_distance = return_value_if_exists(input_xrdf_dict["collision"], "buffer_distance")
buffer_distance = return_value_if_exists(
input_xrdf_dict["collision"], "buffer_distance", raise_error=False
)
if buffer_distance is None:
buffer_distance = 0.0
output_dict["collision_sphere_buffer"] = buffer_distance
output_dict["collision_link_names"] = list(coll_spheres.keys())
@@ -82,8 +90,10 @@ def convert_xrdf_to_curobo(
self_collision_buffer = return_value_if_exists(
input_xrdf_dict["self_collision"],
"buffer_distance",
raise_error=False,
)
if self_collision_buffer is None:
self_collision_buffer = {}
output_dict["self_collision_ignore"] = self_collision_ignore
output_dict["self_collision_buffer"] = self_collision_buffer
else:
@@ -92,10 +102,10 @@ def convert_xrdf_to_curobo(
log_warn("collision key not found in xrdf, collision avoidance is disabled")
tool_frames = return_value_if_exists(input_xrdf_dict, "tool_frames")
output_dict["ee_link"] = tool_frames[0]
output_dict["link_names"] = None
if len(tool_frames) > 1:
output_dict["link_names"] = input_xrdf_dict["tool_frames"]
output_dict["link_names"] = deepcopy(tool_frames)
# cspace:
cspace_dict = return_value_if_exists(input_xrdf_dict, "cspace")
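With raise_error=False, missing optional keys now fall back instead of aborting; a small sketch:

xrdf_collision = {"geometry": "robot_collision"}  # no buffer_distance key

buffer_distance = return_value_if_exists(
    xrdf_collision, "buffer_distance", raise_error=False
)
if buffer_distance is None:
    buffer_distance = 0.0  # same default the converter applies above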

View File

@@ -11,17 +11,33 @@
"""Contains helper functions for interacting with file systems."""
# Standard Library
import os
import re
import shutil
import sys
from typing import Any, Dict, List, Union
# Third Party
import yaml
from yaml import Loader
from yaml import SafeLoader as Loader
# CuRobo
from curobo.util.logger import log_warn
Loader.add_implicit_resolver(
"tag:yaml.org,2002:float",
re.compile(
"""^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$""",
re.X,
),
list("-+0123456789."),
)
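The resolver restores scientific-notation floats, which PyYAML's SafeLoader otherwise returns as strings (the YAML 1.1 regex requires a dot, e.g. 1.0e-4):

import yaml

yaml.safe_load("eps: 1.0e-4")["eps"]  # 0.0001: dot form already parses
yaml.safe_load("eps: 1e-4")["eps"]    # '1e-4' (str) with stock SafeLoader
# After the add_implicit_resolver call above, the bare form also parses as a
# float; note the patch applies to SafeLoader for the whole process.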
# get paths
def get_module_path() -> str:

View File

@@ -242,6 +242,7 @@ class MotionGenConfig:
ik_seed: int = 1531,
graph_seed: int = 1531,
high_precision: bool = False,
use_cuda_graph_trajopt_metrics: bool = False,
):
"""Create a motion generation configuration from robot and world configuration.
@@ -473,6 +474,10 @@ class MotionGenConfig:
the number of iterations for optimization solvers and reduce the thresholds for
position to 1mm and rotation to 0.025. Default of False is recommended for most
cases as standard motion generation settings reach within 0.5mm on most problems.
use_cuda_graph_trajopt_metrics: Flag to enable cuda_graph when evaluating interpolated
trajectories after trajectory optimization. If the interpolation buffer is smaller
than the interpolated trajectory, the buffers will be re-created, which can
invalidate the existing cuda graph.
Returns:
MotionGenConfig: Instance of motion generation configuration.
@@ -722,6 +727,7 @@ class MotionGenConfig:
minimize_jerk=minimize_jerk,
optimize_dt=optimize_dt,
project_pose_to_goal_frame=project_pose_to_goal_frame,
use_cuda_graph_metrics=use_cuda_graph_trajopt_metrics,
)
trajopt_solver = TrajOptSolver(trajopt_cfg)
@@ -763,6 +769,7 @@ class MotionGenConfig:
filter_robot_command=filter_robot_command,
optimize_dt=optimize_dt,
num_seeds=num_trajopt_noisy_seeds,
use_cuda_graph_metrics=use_cuda_graph_trajopt_metrics,
)
js_trajopt_solver = TrajOptSolver(js_trajopt_cfg)
@@ -805,6 +812,7 @@ class MotionGenConfig:
filter_robot_command=filter_robot_command,
optimize_dt=optimize_dt,
project_pose_to_goal_frame=project_pose_to_goal_frame,
use_cuda_graph_metrics=use_cuda_graph_trajopt_metrics,
)
finetune_trajopt_solver = TrajOptSolver(finetune_trajopt_cfg)
@@ -847,6 +855,7 @@ class MotionGenConfig:
filter_robot_command=filter_robot_command,
optimize_dt=optimize_dt,
num_seeds=num_trajopt_noisy_seeds,
use_cuda_graph_metrics=use_cuda_graph_trajopt_metrics,
)
finetune_js_trajopt_solver = TrajOptSolver(finetune_js_trajopt_cfg)
@@ -1379,6 +1388,26 @@ class MotionGenResult:
return current_tensor
@dataclass
class GraspPlanResult:
success: Optional[torch.Tensor] = None
grasp_trajectory: Optional[JointState] = None
grasp_trajectory_dt: Optional[torch.Tensor] = None
grasp_interpolated_trajectory: Optional[JointState] = None
grasp_interpolation_dt: Optional[torch.Tensor] = None
retract_trajectory: Optional[JointState] = None
retract_trajectory_dt: Optional[torch.Tensor] = None
retract_interpolated_trajectory: Optional[JointState] = None
retract_interpolation_dt: Optional[torch.Tensor] = None
approach_result: Optional[MotionGenResult] = None
grasp_result: Optional[MotionGenResult] = None
retract_result: Optional[MotionGenResult] = None
status: Optional[str] = None
goalset_result: Optional[MotionGenResult] = None
planning_time: float = 0.0
goalset_index: Optional[torch.Tensor] = None
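A hedged end-to-end sketch of consuming this result via the plan_grasp method added further below; motion_gen, start_state, and grasp_poses are assumed to be set up as in cuRobo's other examples, and the finger link names are illustrative:

from curobo.wrap.reacher.motion_gen import MotionGenPlanConfig

# grasp_poses: Pose of shape (1, num_grasps, 7); plan_grasp solves a goalset
# problem to pick the best grasp, then plans approach -> grasp -> retract.
result = motion_gen.plan_grasp(
    start_state,
    grasp_poses,
    MotionGenPlanConfig(max_attempts=10),
    disable_collision_links=["left_finger", "right_finger"],  # illustrative
)
if result.success is not None and result.success.item():
    grasp_traj = result.grasp_interpolated_trajectory
    retract_traj = result.retract_interpolated_trajectory
    chosen = result.goalset_index  # index of the selected grasp pose
else:
    print("grasp planning failed:", result.status)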
class MotionGen(MotionGenConfig):
"""Motion generation wrapper for generating collision-free trajectories.
@@ -2167,10 +2196,16 @@ class MotionGen(MotionGenConfig):
Returns:
bool: True if the constraint can be added, False otherwise.
"""
rollouts = self.get_all_pose_rollout_instances()
# check if constraint is valid:
if metric.hold_partial_pose and metric.offset_tstep_fraction < 0.0:
start_pose = self.compute_kinematics(start_state).ee_pose.clone()
if self.project_pose_to_goal_frame:
project_distance = metric.project_to_goal_frame
if project_distance is None:
project_distance = rollouts[0].goal_cost.project_distance
if project_distance:
# project start pose to goal frame:
projected_pose = goal_pose.compute_local_pose(start_pose)
if torch.count_nonzero(metric.hold_vec_weight[:3] > 0.0) > 0:
@@ -2208,7 +2243,6 @@ class MotionGen(MotionGenConfig):
log_warn("Partial position between start and goal is not equal.")
return False
rollouts = self.get_all_pose_rollout_instances()
[
rollout.update_pose_cost_metric(metric)
for rollout in rollouts
@@ -2955,6 +2989,7 @@ class MotionGen(MotionGenConfig):
"""
start_time = time.time()
valid_query = True
plan_config = plan_config.clone()
if plan_config.check_start_validity:
valid_query, status = self.check_start_state(start_state)
if not valid_query:
@@ -3094,6 +3129,7 @@ class MotionGen(MotionGenConfig):
MotionGenResult: Result of batched planning.
"""
start_time = time.time()
plan_config = plan_config.clone()
goal_pose = goal_pose.clone()
if plan_config.pose_cost_metric is not None:
valid_query = self.update_pose_cost_metric(
@@ -4135,3 +4171,243 @@ class MotionGen(MotionGenConfig):
)
result = self.plan_batch(start_state, goal_pose, plan_config)
return result
def toggle_link_collision(self, collision_link_names: List[str], enable_flag: bool):
"""Enable or disable the collision spheres of the given links (e.g., gripper links)."""
if len(collision_link_names) > 0:
if enable_flag:
for k in collision_link_names:
self.kinematics.kinematics_config.enable_link_spheres(k)
else:
for k in collision_link_names:
self.kinematics.kinematics_config.disable_link_spheres(k)
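For example, a sketch of bracketing an in-contact motion with this helper; the link names are robot-specific placeholders:

finger_links = ["left_finger", "right_finger"]  # hypothetical link names
motion_gen.toggle_link_collision(finger_links, False)  # allow contact with the world
# ... plan the in-contact segment here ...
motion_gen.toggle_link_collision(finger_links, True)  # restore collision checking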
def plan_grasp(
self,
start_state: JointState,
grasp_poses: Pose,
plan_config: MotionGenPlanConfig,
grasp_approach_offset: Pose = Pose.from_list([0, 0, -0.15, 1, 0, 0, 0]),
grasp_approach_path_constraint: Union[None, List[float]] = [0.1, 0.1, 0.1, 0.1, 0.1, 0.0],
retract_offset: Pose = Pose.from_list([0, 0, -0.15, 1, 0, 0, 0]),
retract_path_constraint: Union[None, List[float]] = [0.1, 0.1, 0.1, 0.1, 0.1, 0.0],
disable_collision_links: List[str] = [],
plan_approach_to_grasp: bool = True,
plan_grasp_to_retract: bool = True,
grasp_approach_constraint_in_goal_frame: bool = True,
retract_constraint_in_goal_frame: bool = True,
) -> GraspPlanResult:
"""Plan a sequence of motions to grasp an object, given a set of grasp poses.
This function plans three motions: it first approaches the object at an offset from the
grasp pose, then moves to the grasp pose under linear Cartesian constraints, and finally
retracts the arm to a retract offset, again under linear constraints. During the
constrained motions, collision between disable_collision_links and the world is
disabled, which allows contact between a robot's gripper links and the object.
To select a grasp, this method solves a goal-set trajectory optimization problem in
which the robot must reach one of the poses in grasp_poses at the terminal state. To
allow for in-contact grasps, collision between disable_collision_links and the world is
also disabled during this optimization. The best grasp pose found is then used to plan
the three motions.
Args:
start_state: Start joint state for planning.
grasp_poses: Set of grasp poses, represented with :class:`~curobo.types.math.Pose`,
of shape (1, num_grasps, 7).
plan_config: Planning parameters for motion generation.
grasp_approach_offset: Offset pose from the grasp pose. Reference frame is the grasp
pose frame if grasp_approach_constraint_in_goal_frame is True, otherwise the
reference frame is the robot base frame.
grasp_approach_path_constraint: Path constraint for the approach-to-grasp motion.
This is a list of 6 weights, one per Cartesian dimension: the first three weight
orientation and the last three weight position. If None, no path constraint is
applied.
retract_offset: Retract offset pose from grasp pose. Reference frame is the grasp pose
frame if retract_constraint_in_goal_frame is True, otherwise the reference frame is
the robot base frame.
retract_path_constraint: Path constraint for the grasp-to-retract motion. This is a
list of 6 weights, one per Cartesian dimension: the first three weight orientation
and the last three weight position. If None, no path constraint is applied.
disable_collision_links: Names of links whose collision with the world is disabled
during the approach-to-grasp and grasp-to-retract motions.
plan_approach_to_grasp: If True, planning also includes moving from the approach pose
to the grasp pose. If False, only a plan to reach the offset of the best grasp pose
is returned.
plan_grasp_to_retract: If True, planning also includes moving from grasp to retract.
If False, only a plan to reach the best grasp pose is returned.
grasp_approach_constraint_in_goal_frame: If True, the grasp approach offset is in the
grasp pose frame. If False, the grasp approach offset is in the robot base frame.
Also applies to grasp_approach_path_constraint.
retract_constraint_in_goal_frame: If True, the retract offset is in the grasp pose
frame. If False, the retract offset is in the robot base frame. Also applies to
retract_path_constraint.
Returns:
GraspPlanResult: Result of planning. Use :attr:`GraspPlanResult.grasp_trajectory` to
get the trajectory that reaches the grasp pose and
:attr:`GraspPlanResult.retract_trajectory` to get the trajectory that retracts
from the grasp pose.
"""
if plan_config.pose_cost_metric is not None:
log_error("plan_config.pose_cost_metric should be None")
self.toggle_link_collision(disable_collision_links, False)
result = GraspPlanResult()
goalset_motion_gen_result = self.plan_goalset(
start_state,
grasp_poses,
plan_config,
)
self.toggle_link_collision(disable_collision_links, True)
result.success = goalset_motion_gen_result.success.clone()
result.success[:] = False
result.goalset_result = goalset_motion_gen_result
if not goalset_motion_gen_result.success.item():
result.status = "No grasp in goal set was reachable."
return result
result.goalset_index = goalset_motion_gen_result.goalset_index.clone()
# plan to offset:
goal_index = goalset_motion_gen_result.goalset_index.item()
goal_pose = grasp_poses.get_index(0, goal_index).clone()
if grasp_approach_constraint_in_goal_frame:
offset_goal_pose = goal_pose.clone().multiply(grasp_approach_offset)
else:
offset_goal_pose = grasp_approach_offset.clone().multiply(goal_pose.clone())
reach_offset_mg_result = self.plan_single(
start_state,
offset_goal_pose,
plan_config.clone(),
)
result.approach_result = reach_offset_mg_result
if not reach_offset_mg_result.success.item():
result.status = f"Planning to Approach pose failed: {reach_offset_mg_result.status}"
return result
if not plan_approach_to_grasp:
result.success[:] = True
result.grasp_trajectory = reach_offset_mg_result.optimized_plan
result.grasp_trajectory_dt = reach_offset_mg_result.optimized_dt
result.grasp_interpolated_trajectory = reach_offset_mg_result.get_interpolated_plan()
result.grasp_interpolation_dt = reach_offset_mg_result.interpolation_dt
result.planning_time = (
goalset_motion_gen_result.total_time + reach_offset_mg_result.total_time
)
return result
# plan to final grasp
if grasp_approach_path_constraint is not None:
hold_pose_cost_metric = PoseCostMetric(
hold_partial_pose=True,
hold_vec_weight=self.tensor_args.to_device(grasp_approach_path_constraint),
project_to_goal_frame=grasp_approach_constraint_in_goal_frame,
)
plan_config.pose_cost_metric = hold_pose_cost_metric
offset_start_state = reach_offset_mg_result.optimized_plan[-1].unsqueeze(0)
self.toggle_link_collision(disable_collision_links, False)
reach_grasp_mg_result = self.plan_single(
offset_start_state,
goal_pose,
plan_config,
)
self.toggle_link_collision(disable_collision_links, True)
result.grasp_result = reach_grasp_mg_result
if not reach_grasp_mg_result.success.item():
result.status = (
f"Planning from Approach to Grasp Failed: {reach_grasp_mg_result.status}"
)
return result
# Get stitched trajectory:
offset_dt = reach_offset_mg_result.optimized_dt
grasp_dt = reach_grasp_mg_result.optimized_dt
if offset_dt > grasp_dt:
# retime grasp trajectory to match offset trajectory:
grasp_time_dilation = grasp_dt / offset_dt
reach_grasp_mg_result.retime_trajectory(
grasp_time_dilation,
interpolate_trajectory=True,
)
else:
# retime offset trajectory to match grasp trajectory:
offset_time_dilation = offset_dt / grasp_dt
reach_offset_mg_result.retime_trajectory(
offset_time_dilation,
interpolate_trajectory=True,
)
if (reach_offset_mg_result.optimized_dt - reach_grasp_mg_result.optimized_dt).abs() > 0.01:
reach_offset_mg_result.success[:] = False
if reach_offset_mg_result.debug_info is None:
reach_offset_mg_result.debug_info = {}
reach_offset_mg_result.debug_info["plan_single_grasp_status"] = (
"Stitching Trajectories Failed"
)
result.status = "Stitching approach and grasp trajectories failed."
return result
result.grasp_trajectory = reach_offset_mg_result.optimized_plan.stack(
reach_grasp_mg_result.optimized_plan
).clone()
result.grasp_trajectory_dt = reach_offset_mg_result.optimized_dt
result.grasp_interpolated_trajectory = (
reach_offset_mg_result.get_interpolated_plan()
.stack(reach_grasp_mg_result.get_interpolated_plan())
.clone()
)
result.grasp_interpolation_dt = reach_offset_mg_result.interpolation_dt
# update trajectories in results:
result.planning_time = (
reach_offset_mg_result.total_time
+ reach_grasp_mg_result.total_time
+ goalset_motion_gen_result.total_time
)
# check if retract path is required:
result.success[:] = True
if not plan_grasp_to_retract:
return result
result.success[:] = False
self.toggle_link_collision(disable_collision_links, False)
grasp_start_state = result.grasp_trajectory[-1].unsqueeze(0)
# compute retract goal pose:
if retract_constraint_in_goal_frame:
retract_goal_pose = goal_pose.clone().multiply(retract_offset)
else:
retract_goal_pose = retract_offset.clone().multiply(goal_pose.clone())
# add path constraint for retract:
plan_config.pose_cost_metric = None
if retract_path_constraint is not None:
hold_pose_cost_metric = PoseCostMetric(
hold_partial_pose=True,
hold_vec_weight=self.tensor_args.to_device(retract_path_constraint),
project_to_goal_frame=retract_constraint_in_goal_frame,
)
plan_config.pose_cost_metric = hold_pose_cost_metric
# plan from grasp pose to retract:
retract_grasp_mg_result = self.plan_single(
grasp_start_state,
retract_goal_pose,
plan_config,
)
self.toggle_link_collision(disable_collision_links, True)
result.planning_time += retract_grasp_mg_result.total_time
result.retract_result = retract_grasp_mg_result
if not retract_grasp_mg_result.success.item():
result.status = f"Retract from Grasp failed: {retract_grasp_mg_result.status}"
return result
result.success[:] = True
result.retract_trajectory = retract_grasp_mg_result.optimized_plan
result.retract_trajectory_dt = retract_grasp_mg_result.optimized_dt
result.retract_interpolated_trajectory = retract_grasp_mg_result.get_interpolated_plan()
result.retract_interpolation_dt = retract_grasp_mg_result.interpolation_dt
return result
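An end-to-end sketch of the new API; grasp poses would normally come from a grasp sampler, start_state is assumed to exist, and the finger link names are robot-specific placeholders:

from curobo.types.math import Pose
from curobo.wrap.reacher.motion_gen import MotionGenPlanConfig

# Two candidate top-down grasps packed as a goal set of shape (1, num_grasps, 7).
grasp_poses = Pose(
    position=motion_gen.tensor_args.to_device([[[0.4, 0.0, 0.3], [0.4, 0.05, 0.3]]]),
    quaternion=motion_gen.tensor_args.to_device(
        [[[0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]
    ),
)
result = motion_gen.plan_grasp(
    start_state,
    grasp_poses,
    MotionGenPlanConfig(max_attempts=10),
    disable_collision_links=["left_finger", "right_finger"],
)
if result.success.item():
    print("Selected grasp index:", result.goalset_index.item())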

View File

@@ -148,6 +148,7 @@ class TrajOptSolverConfig:
filter_robot_command: bool = False,
optimize_dt: bool = True,
project_pose_to_goal_frame: bool = True,
use_cuda_graph_metrics: bool = False,
):
"""Load TrajOptSolver configuration from robot configuration.
@@ -290,6 +291,10 @@ class TrajOptSolverConfig:
project_pose_to_goal_frame: Project pose to goal frame when calculating distance
between reached and goal pose. Use this to constrain motion to specific axes
either in the global frame or the goal frame.
use_cuda_graph_metrics: Flag to enable CUDA graph capture when evaluating interpolated
trajectories after trajectory optimization. If the interpolation buffer is smaller
than the interpolated trajectory, the buffers are re-created, which invalidates any
previously captured CUDA graph.
Returns:
TrajOptSolverConfig: Trajectory optimization configuration.
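The same flag at the solver level, as a sketch; this assumes the standard load_from_robot_config entry point with its world_model argument, and uses example config file names:

from curobo.wrap.reacher.trajopt import TrajOptSolver, TrajOptSolverConfig

trajopt_config = TrajOptSolverConfig.load_from_robot_config(
    "franka.yml",  # example robot config
    world_model="collision_table.yml",  # example world config
    use_cuda_graph=True,
    use_cuda_graph_metrics=True,  # previously hard-wired to use_cuda_graph
)
trajopt_solver = TrajOptSolver(trajopt_config)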
@@ -508,7 +513,7 @@ class TrajOptSolverConfig:
safety_rollout=arm_rollout_safety,
optimizers=opt_list,
compute_metrics=True,
use_cuda_graph_metrics=use_cuda_graph,
use_cuda_graph_metrics=use_cuda_graph_metrics,
sync_cuda_time=sync_cuda_time,
)
trajopt = WrapBase(cfg)
@@ -539,7 +544,7 @@ class TrajOptSolverConfig:
tensor_args=tensor_args,
sync_cuda_time=sync_cuda_time,
interpolate_rollout=interpolate_rollout,
use_cuda_graph_metrics=use_cuda_graph,
use_cuda_graph_metrics=use_cuda_graph_metrics,
trim_steps=trim_steps,
store_debug_in_result=store_debug_in_result,
optimize_dt=optimize_dt,
@@ -720,7 +725,7 @@ class TrajOptSolver(TrajOptSolverConfig):
link_name: Name of the link to attach the spheres to. Note that this link should
already have pre-allocated spheres.
"""
self.kinematics.attach_object(
self.kinematics.kinematics_config.attach_object(
sphere_radius=sphere_radius, sphere_tensor=sphere_tensor, link_name=link_name
)
@@ -730,7 +735,7 @@ class TrajOptSolver(TrajOptSolverConfig):
Args:
link_name: Name of the link to detach the spheres from.
"""
self.kinematics.detach_object(link_name)
self.kinematics.kinematics_config.detach_object(link_name)
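A sketch of the attach/detach pair these bodies belong to, assuming the enclosing methods are the solver's attach_spheres_to_robot and detach_spheres_from_robot helpers; the sphere values are illustrative, and "attached_object" assumes a link with pre-allocated spheres in the robot config:

# One sphere at (0, 0, 0.05) with radius 0.03, in (x, y, z, radius) layout.
sphere_tensor = trajopt_solver.tensor_args.to_device([[0.0, 0.0, 0.05, 0.03]])
trajopt_solver.attach_spheres_to_robot(
    sphere_tensor=sphere_tensor,
    link_name="attached_object",
)
# ... solve with the object attached ...
trajopt_solver.detach_spheres_from_robot(link_name="attached_object")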
def _update_solve_state_and_goal_buffer(
self,