constrained planning, robot segmentation
@@ -16,43 +16,55 @@
// CUDA forward declarations

std::vector<torch::Tensor> self_collision_distance(
    torch::Tensor out_distance, torch::Tensor out_vec,
    torch::Tensor sparse_index,
    const torch::Tensor robot_spheres, // batch_size x n_spheres x 4
    const torch::Tensor collision_offset, // n_spheres x n_spheres
    const torch::Tensor weight,
    const torch::Tensor collision_matrix, // n_spheres x n_spheres
    const torch::Tensor thread_locations, const int locations_size,
    const int batch_size, const int nspheres, const bool compute_grad = false,
    const int ndpt = 8, // Does this need to match template?
    const bool debug = false);
std::vector<torch::Tensor>self_collision_distance(
    torch::Tensor out_distance,
    torch::Tensor out_vec,
    torch::Tensor sparse_index,
    const torch::Tensor robot_spheres, // batch_size x n_spheres x 4
    const torch::Tensor collision_offset, // n_spheres x n_spheres
    const torch::Tensor weight,
    const torch::Tensor collision_matrix, // n_spheres x n_spheres
    const torch::Tensor thread_locations,
    const int locations_size,
    const int batch_size,
    const int nspheres,
    const bool compute_grad = false,
    const int ndpt = 8, // Does this need to match template?
    const bool debug = false);
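A minimal sketch of the per-pair test such a self-collision-distance entry point evaluates, assuming the usual sphere-sphere formulation; all names here are hypothetical and the actual CUDA kernel is not shown in this diff:

#include <cmath>

struct Sphere { float x, y, z, r; };

// Signed distance between two collision spheres; negative means overlap.
// An offset term (cf. collision_offset above) widens or shrinks the margin.
inline float sphere_pair_distance(const Sphere& a, const Sphere& b,
                                  float offset) {
  const float dx = a.x - b.x, dy = a.y - b.y, dz = a.z - b.z;
  const float center_dist = std::sqrt(dx * dx + dy * dy + dz * dz);
  return center_dist - (a.r + b.r) - offset;
}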

// CUDA forward declarations

std::vector<torch::Tensor> swept_sphere_obb_clpt(
    const torch::Tensor sphere_position, // batch_size, 3
    torch::Tensor distance, // batch_size, 1
    torch::Tensor
    closest_point, // batch size, 4 -> written out as x,y,z,0 for gradient
    torch::Tensor sparsity_idx, const torch::Tensor weight,
    const torch::Tensor activation_distance, const torch::Tensor speed_dt,
    const torch::Tensor obb_accel, // n_boxes, 4, 4
    const torch::Tensor obb_bounds, // n_boxes, 3
    const torch::Tensor obb_pose, // n_boxes, 4, 4
    const torch::Tensor obb_enable, // n_boxes, 4,
    const torch::Tensor n_env_obb, // n_boxes, 4, 4
    const torch::Tensor env_query_idx, // n_boxes, 4, 4
    const int max_nobs, const int batch_size, const int horizon,
    const int n_spheres, const int sweep_steps, const bool enable_speed_metric,
    const bool transform_back, const bool compute_distance,
    const bool use_batch_env);
std::vector<torch::Tensor>swept_sphere_obb_clpt(
    const torch::Tensor sphere_position, // batch_size, 3
    torch::Tensor distance, // batch_size, 1
    torch::Tensor
    closest_point, // batch size, 4 -> written out as x,y,z,0 for gradient
    torch::Tensor sparsity_idx,
    const torch::Tensor weight,
    const torch::Tensor activation_distance,
    const torch::Tensor speed_dt,
    const torch::Tensor obb_accel, // n_boxes, 4, 4
    const torch::Tensor obb_bounds, // n_boxes, 3
    const torch::Tensor obb_pose, // n_boxes, 4, 4
    const torch::Tensor obb_enable, // n_boxes, 4,
    const torch::Tensor n_env_obb, // n_boxes, 4, 4
    const torch::Tensor env_query_idx, // n_boxes, 4, 4
    const int max_nobs,
    const int batch_size,
    const int horizon,
    const int n_spheres,
    const int sweep_steps,
    const bool enable_speed_metric,
    const bool transform_back,
    const bool compute_distance,
    const bool use_batch_env);

std::vector<torch::Tensor>
sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 4
    torch::Tensor distance,
    torch::Tensor closest_point, // batch size, 3
    torch::Tensor sparsity_idx, const torch::Tensor weight,
    torch::Tensor distance,
    torch::Tensor closest_point, // batch size, 3
    torch::Tensor sparsity_idx,
    const torch::Tensor weight,
    const torch::Tensor activation_distance,
    const torch::Tensor obb_accel, // n_boxes, 4, 4
    const torch::Tensor obb_bounds, // n_boxes, 3
@@ -60,58 +72,75 @@ sphere_obb_clpt(const torch::Tensor sphere_position, // batch_size, 4
    const torch::Tensor obb_enable, // n_boxes, 4, 4
    const torch::Tensor n_env_obb, // n_boxes, 4, 4
    const torch::Tensor env_query_idx, // n_boxes, 4, 4
    const int max_nobs, const int batch_size, const int horizon,
    const int n_spheres, const bool transform_back,
    const bool compute_distance, const bool use_batch_env);
    const int max_nobs,
    const int batch_size,
    const int horizon,
    const int n_spheres,
    const bool transform_back,
    const bool compute_distance,
    const bool use_batch_env);
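A minimal sketch of the standard closest-point-to-OBB computation these declarations suggest, assuming the query point has already been transformed into the box frame (obb_pose above would supply that transform); hypothetical names, and the real kernels additionally handle batching, sweeping, gradients, and sparsity:

#include <algorithm>
#include <cmath>

struct Vec3 { float x, y, z; };

// p_box: query point in the box's local frame; half: half-extents
// (cf. obb_bounds). Clamping to the extents yields the closest point;
// the residual length is the distance (0 when the point is inside).
inline float point_obb_distance(const Vec3& p_box, const Vec3& half,
                                Vec3& closest) {
  closest.x = std::clamp(p_box.x, -half.x, half.x);
  closest.y = std::clamp(p_box.y, -half.y, half.y);
  closest.z = std::clamp(p_box.z, -half.z, half.z);
  const float dx = p_box.x - closest.x;
  const float dy = p_box.y - closest.y;
  const float dz = p_box.z - closest.z;
  return std::sqrt(dx * dx + dy * dy + dz * dz);
}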

std::vector<torch::Tensor> pose_distance(
    torch::Tensor out_distance, torch::Tensor out_position_distance,
    torch::Tensor out_rotation_distance,
    torch::Tensor distance_p_vector, // batch size, 3
    torch::Tensor distance_q_vector, // batch size, 4
    torch::Tensor out_gidx,
    const torch::Tensor current_position, // batch_size, 3
    const torch::Tensor goal_position, // n_boxes, 3
    const torch::Tensor current_quat, const torch::Tensor goal_quat,
    const torch::Tensor vec_weight, // n_boxes, 4, 4
    const torch::Tensor weight, // n_boxes, 4, 4
    const torch::Tensor vec_convergence, const torch::Tensor run_weight,
    const torch::Tensor run_vec_weight, const torch::Tensor batch_pose_idx,
    const int batch_size, const int horizon, const int mode,
    const int num_goals = 1, const bool compute_grad = false,
    const bool write_distance = true, const bool use_metric = false);
std::vector<torch::Tensor>pose_distance(
    torch::Tensor out_distance,
    torch::Tensor out_position_distance,
    torch::Tensor out_rotation_distance,
    torch::Tensor distance_p_vector, // batch size, 3
    torch::Tensor distance_q_vector, // batch size, 4
    torch::Tensor out_gidx,
    const torch::Tensor current_position, // batch_size, 3
    const torch::Tensor goal_position, // n_boxes, 3
    const torch::Tensor current_quat,
    const torch::Tensor goal_quat,
    const torch::Tensor vec_weight, // n_boxes, 4, 4
    const torch::Tensor weight, // n_boxes, 4, 4
    const torch::Tensor vec_convergence,
    const torch::Tensor run_weight,
    const torch::Tensor run_vec_weight,
    const torch::Tensor offset_waypoint,
    const torch::Tensor offset_tstep_fraction,
    const torch::Tensor batch_pose_idx,
    const int batch_size,
    const int horizon,
    const int mode,
    const int num_goals = 1,
    const bool compute_grad = false,
    const bool write_distance = true,
    const bool use_metric = false,
    const bool project_distance = true);
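A minimal sketch of the usual decomposition of pose error into translational and rotational terms, which is what a pose-distance entry point of this shape computes; the weighting (vec_weight, weight, run_weight above) and the projected/metric variants are omitted, and this is an illustration, not the kernel's verified math:

#include <cmath>

// Quaternions as (w, x, y, z), assumed unit-norm.
inline float position_distance(const float p[3], const float g[3]) {
  const float d0 = p[0] - g[0], d1 = p[1] - g[1], d2 = p[2] - g[2];
  return std::sqrt(d0 * d0 + d1 * d1 + d2 * d2);
}

inline float rotation_distance(const float q[4], const float gq[4]) {
  float dot = q[0] * gq[0] + q[1] * gq[1] + q[2] * gq[2] + q[3] * gq[3];
  dot = std::fabs(dot);         // q and -q encode the same rotation
  if (dot > 1.0f) dot = 1.0f;   // guard acos against rounding
  return 2.0f * std::acos(dot); // geodesic angle between orientations
}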

std::vector<torch::Tensor>
backward_pose_distance(torch::Tensor out_grad_p, torch::Tensor out_grad_q,
backward_pose_distance(torch::Tensor out_grad_p,
    torch::Tensor out_grad_q,
    const torch::Tensor grad_distance, // batch_size, 3
    const torch::Tensor grad_p_distance, // n_boxes, 3
    const torch::Tensor grad_q_distance,
    const torch::Tensor pose_weight,
    const torch::Tensor grad_p_vec, // n_boxes, 4, 4
    const torch::Tensor grad_q_vec, const int batch_size,
    const bool use_distance = false);
    const torch::Tensor grad_p_vec, // n_boxes, 4, 4
    const torch::Tensor grad_q_vec,
    const int batch_size,
    const bool use_distance = false);

// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x); \
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), # x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    AT_ASSERTM(x.is_contiguous(), # x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x); \
    CHECK_CONTIGUOUS(x)
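A minimal sketch of how these guard macros are used in a wrapper, with the same AT_ASSERTM-based definitions given above (torch/extension.h provides AT_ASSERTM and torch::Tensor); scale_inplace is a hypothetical example function:

#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor scale_inplace(torch::Tensor v, double s) {
  CHECK_INPUT(v); // expands to both assertions; throws on a CPU or
                  // non-contiguous tensor before any kernel launch
  return v.mul_(s);
}

Note that CHECK_INPUT expands to two statements, so it is only safe as a standalone statement, not as the body of an unbraced if.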

std::vector<torch::Tensor> self_collision_distance_wrapper(
    torch::Tensor out_distance, torch::Tensor out_vec,
    torch::Tensor sparse_index,
    const torch::Tensor robot_spheres, // batch_size x n_spheres x 4
    const torch::Tensor collision_offset, // n_spheres
    const torch::Tensor weight,
    const torch::Tensor collision_matrix, // n_spheres
    const torch::Tensor thread_locations, const int thread_locations_size,
    const int batch_size, const int nspheres, const bool compute_grad = false,
    const int ndpt = 8, const bool debug = false) {

std::vector<torch::Tensor>self_collision_distance_wrapper(
    torch::Tensor out_distance, torch::Tensor out_vec,
    torch::Tensor sparse_index,
    const torch::Tensor robot_spheres, // batch_size x n_spheres x 4
    const torch::Tensor collision_offset, // n_spheres
    const torch::Tensor weight,
    const torch::Tensor collision_matrix, // n_spheres
    const torch::Tensor thread_locations, const int thread_locations_size,
    const int batch_size, const int nspheres, const bool compute_grad = false,
    const int ndpt = 8, const bool debug = false)
{
  CHECK_INPUT(out_distance);
  CHECK_INPUT(out_vec);
  CHECK_INPUT(robot_spheres);
@@ -123,29 +152,30 @@ std::vector<torch::Tensor> self_collision_distance_wrapper(
  const at::cuda::OptionalCUDAGuard guard(robot_spheres.device());

  return self_collision_distance(
      out_distance, out_vec, sparse_index, robot_spheres,
      collision_offset, weight, collision_matrix, thread_locations,
      thread_locations_size, batch_size, nspheres, compute_grad, ndpt, debug);
}

std::vector<torch::Tensor> sphere_obb_clpt_wrapper(
    const torch::Tensor sphere_position, // batch_size, 4

    torch::Tensor distance,
    torch::Tensor closest_point, // batch size, 3
    torch::Tensor sparsity_idx, const torch::Tensor weight,
    const torch::Tensor activation_distance,
    const torch::Tensor obb_accel, // n_boxes, 4, 4
    const torch::Tensor obb_bounds, // n_boxes, 3
    const torch::Tensor obb_pose, // n_boxes, 4, 4
    const torch::Tensor obb_enable, // n_boxes, 4, 4
    const torch::Tensor n_env_obb, // n_boxes, 4, 4
    const torch::Tensor env_query_idx, // n_boxes, 4, 4
    const int max_nobs, const int batch_size, const int horizon,
    const int n_spheres, const bool transform_back, const bool compute_distance,
    const bool use_batch_env) {
std::vector<torch::Tensor>sphere_obb_clpt_wrapper(
    const torch::Tensor sphere_position, // batch_size, 4

    torch::Tensor distance,
    torch::Tensor closest_point, // batch size, 3
    torch::Tensor sparsity_idx, const torch::Tensor weight,
    const torch::Tensor activation_distance,
    const torch::Tensor obb_accel, // n_boxes, 4, 4
    const torch::Tensor obb_bounds, // n_boxes, 3
    const torch::Tensor obb_pose, // n_boxes, 4, 4
    const torch::Tensor obb_enable, // n_boxes, 4, 4
    const torch::Tensor n_env_obb, // n_boxes, 4, 4
    const torch::Tensor env_query_idx, // n_boxes, 4, 4
    const int max_nobs, const int batch_size, const int horizon,
    const int n_spheres, const bool transform_back, const bool compute_distance,
    const bool use_batch_env)
{
  const at::cuda::OptionalCUDAGuard guard(sphere_position.device());

  CHECK_INPUT(distance);
  CHECK_INPUT(closest_point);
  CHECK_INPUT(sphere_position);
@@ -154,55 +184,61 @@ std::vector<torch::Tensor> sphere_obb_clpt_wrapper(
  CHECK_INPUT(activation_distance);
  CHECK_INPUT(obb_accel);
  return sphere_obb_clpt(
      sphere_position, distance, closest_point, sparsity_idx, weight,
      activation_distance, obb_accel, obb_bounds, obb_pose, obb_enable,
      n_env_obb, env_query_idx, max_nobs, batch_size, horizon, n_spheres,
      transform_back, compute_distance, use_batch_env);
}
std::vector<torch::Tensor> swept_sphere_obb_clpt_wrapper(
    const torch::Tensor sphere_position, // batch_size, 4
    torch::Tensor distance,
    torch::Tensor closest_point, // batch size, 3
    torch::Tensor sparsity_idx, const torch::Tensor weight,
    const torch::Tensor activation_distance, const torch::Tensor speed_dt,
    const torch::Tensor obb_accel, // n_boxes, 4, 4
    const torch::Tensor obb_bounds, // n_boxes, 3
    const torch::Tensor obb_pose, // n_boxes, 4, 4
    const torch::Tensor obb_enable, // n_boxes, 4, 4
    const torch::Tensor n_env_obb, // n_boxes, 4, 4
    const torch::Tensor env_query_idx, // n_boxes, 4, 4
    const int max_nobs, const int batch_size, const int horizon,
    const int n_spheres, const int sweep_steps, const bool enable_speed_metric,
    const bool transform_back, const bool compute_distance,
    const bool use_batch_env) {

std::vector<torch::Tensor>swept_sphere_obb_clpt_wrapper(
    const torch::Tensor sphere_position, // batch_size, 4
    torch::Tensor distance,
    torch::Tensor closest_point, // batch size, 3
    torch::Tensor sparsity_idx, const torch::Tensor weight,
    const torch::Tensor activation_distance, const torch::Tensor speed_dt,
    const torch::Tensor obb_accel, // n_boxes, 4, 4
    const torch::Tensor obb_bounds, // n_boxes, 3
    const torch::Tensor obb_pose, // n_boxes, 4, 4
    const torch::Tensor obb_enable, // n_boxes, 4, 4
    const torch::Tensor n_env_obb, // n_boxes, 4, 4
    const torch::Tensor env_query_idx, // n_boxes, 4, 4
    const int max_nobs, const int batch_size, const int horizon,
    const int n_spheres, const int sweep_steps, const bool enable_speed_metric,
    const bool transform_back, const bool compute_distance,
    const bool use_batch_env)
{
  const at::cuda::OptionalCUDAGuard guard(sphere_position.device());

  CHECK_INPUT(distance);
  CHECK_INPUT(closest_point);
  CHECK_INPUT(sphere_position);

  return swept_sphere_obb_clpt(
      sphere_position,
      distance, closest_point, sparsity_idx, weight, activation_distance,
      speed_dt, obb_accel, obb_bounds, obb_pose, obb_enable, n_env_obb,
      env_query_idx, max_nobs, batch_size, horizon, n_spheres, sweep_steps,
      enable_speed_metric, transform_back, compute_distance, use_batch_env);
}
std::vector<torch::Tensor> pose_distance_wrapper(
    torch::Tensor out_distance, torch::Tensor out_position_distance,
    torch::Tensor out_rotation_distance,
    torch::Tensor distance_p_vector, // batch size, 3
    torch::Tensor distance_q_vector, // batch size, 4
    torch::Tensor out_gidx,
    const torch::Tensor current_position, // batch_size, 3
    const torch::Tensor goal_position, // n_boxes, 3
    const torch::Tensor current_quat, const torch::Tensor goal_quat,
    const torch::Tensor vec_weight, // n_boxes, 4, 4
    const torch::Tensor weight, const torch::Tensor vec_convergence,
    const torch::Tensor run_weight, const torch::Tensor run_vec_weight,
    const torch::Tensor batch_pose_idx, const int batch_size, const int horizon,
    const int mode, const int num_goals = 1, const bool compute_grad = false,
    const bool write_distance = false, const bool use_metric = false) {

std::vector<torch::Tensor>pose_distance_wrapper(
    torch::Tensor out_distance, torch::Tensor out_position_distance,
    torch::Tensor out_rotation_distance,
    torch::Tensor distance_p_vector, // batch size, 3
    torch::Tensor distance_q_vector, // batch size, 4
    torch::Tensor out_gidx,
    const torch::Tensor current_position, // batch_size, 3
    const torch::Tensor goal_position, // n_boxes, 3
    const torch::Tensor current_quat, const torch::Tensor goal_quat,
    const torch::Tensor vec_weight, // n_boxes, 4, 4
    const torch::Tensor weight, const torch::Tensor vec_convergence,
    const torch::Tensor run_weight, const torch::Tensor run_vec_weight,
    const torch::Tensor offset_waypoint, const torch::Tensor offset_tstep_fraction,
    const torch::Tensor batch_pose_idx, const int batch_size, const int horizon,
    const int mode, const int num_goals = 1, const bool compute_grad = false,
    const bool write_distance = false, const bool use_metric = false,
    const bool project_distance = true)
{
  // at::cuda::DeviceGuard guard(angle.device());
  CHECK_INPUT(out_distance);
  CHECK_INPUT(out_position_distance);
@@ -214,24 +250,30 @@ std::vector<torch::Tensor> pose_distance_wrapper(
  CHECK_INPUT(current_quat);
  CHECK_INPUT(goal_quat);
  CHECK_INPUT(batch_pose_idx);
  CHECK_INPUT(offset_waypoint);
  CHECK_INPUT(offset_tstep_fraction);
  const at::cuda::OptionalCUDAGuard guard(current_position.device());

  return pose_distance(
      out_distance, out_position_distance, out_rotation_distance,
      distance_p_vector, distance_q_vector, out_gidx, current_position,
      goal_position, current_quat, goal_quat, vec_weight, weight,
      vec_convergence, run_weight, run_vec_weight, batch_pose_idx, batch_size,
      horizon, mode, num_goals, compute_grad, write_distance, use_metric);
      out_distance, out_position_distance, out_rotation_distance,
      distance_p_vector, distance_q_vector, out_gidx, current_position,
      goal_position, current_quat, goal_quat, vec_weight, weight,
      vec_convergence, run_weight, run_vec_weight,
      offset_waypoint,
      offset_tstep_fraction,
      batch_pose_idx, batch_size,
      horizon, mode, num_goals, compute_grad, write_distance, use_metric, project_distance);
}
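Every wrapper in this file follows the same check-guard-forward pattern; a minimal sketch, assuming a hypothetical kernel entry point my_kernel_launcher (the CHECK_INPUT macro is the one defined earlier in the file):

#include <ATen/cuda/CUDAGuard.h>
#include <torch/extension.h>

std::vector<torch::Tensor> my_kernel_launcher(torch::Tensor, torch::Tensor);

std::vector<torch::Tensor> my_wrapper(torch::Tensor out, const torch::Tensor in) {
  CHECK_INPUT(out);
  CHECK_INPUT(in);
  // Select the device that owns the inputs for every subsequent launch;
  // OptionalCUDAGuard restores the previous device on scope exit.
  const at::cuda::OptionalCUDAGuard guard(in.device());
  return my_kernel_launcher(out, in);
}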

std::vector<torch::Tensor> backward_pose_distance_wrapper(
    torch::Tensor out_grad_p, torch::Tensor out_grad_q,
    const torch::Tensor grad_distance, // batch_size, 3
    const torch::Tensor grad_p_distance, // n_boxes, 3
    const torch::Tensor grad_q_distance, const torch::Tensor pose_weight,
    const torch::Tensor grad_p_vec, // n_boxes, 4, 4
    const torch::Tensor grad_q_vec, const int batch_size,
    const bool use_distance) {
std::vector<torch::Tensor>backward_pose_distance_wrapper(
    torch::Tensor out_grad_p, torch::Tensor out_grad_q,
    const torch::Tensor grad_distance, // batch_size, 3
    const torch::Tensor grad_p_distance, // n_boxes, 3
    const torch::Tensor grad_q_distance, const torch::Tensor pose_weight,
    const torch::Tensor grad_p_vec, // n_boxes, 4, 4
    const torch::Tensor grad_q_vec, const int batch_size,
    const bool use_distance)
{
  CHECK_INPUT(out_grad_p);
  CHECK_INPUT(out_grad_q);
  CHECK_INPUT(grad_distance);
@@ -241,18 +283,19 @@ std::vector<torch::Tensor> backward_pose_distance_wrapper(
  const at::cuda::OptionalCUDAGuard guard(grad_distance.device());

  return backward_pose_distance(
      out_grad_p, out_grad_q, grad_distance, grad_p_distance, grad_q_distance,
      pose_weight, grad_p_vec, grad_q_vec, batch_size, use_distance);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("pose_distance", &pose_distance_wrapper, "Pose Distance (curobolib)");
  m.def("pose_distance_backward", &backward_pose_distance_wrapper,
        "Pose Distance Backward (curobolib)");

  m.def("closest_point", &sphere_obb_clpt_wrapper,
        "Closest Point OBB(curobolib)");
  m.def("swept_closest_point", &swept_sphere_obb_clpt_wrapper,
        "Swept Closest Point OBB(curobolib)");

  m.def("self_collision_distance", &self_collision_distance_wrapper,

@@ -16,91 +16,112 @@

// CUDA forward declarations

std::vector<torch::Tensor>
matrix_to_quaternion(torch::Tensor out_quat,
    const torch::Tensor in_rot // batch_size, 3
    );

std::vector<torch::Tensor> kin_fused_forward(
    torch::Tensor link_pos, torch::Tensor link_quat,
    torch::Tensor batch_robot_spheres, torch::Tensor global_cumul_mat,
    const torch::Tensor joint_vec, const torch::Tensor fixed_transform,
    const torch::Tensor robot_spheres, const torch::Tensor link_map,
    const torch::Tensor joint_map, const torch::Tensor joint_map_type,
    const torch::Tensor store_link_map, const torch::Tensor link_sphere_map,
    const int batch_size, const int n_spheres,
    const bool use_global_cumul = false);
std::vector<torch::Tensor>kin_fused_forward(
    torch::Tensor link_pos,
    torch::Tensor link_quat,
    torch::Tensor batch_robot_spheres,
    torch::Tensor global_cumul_mat,
    const torch::Tensor joint_vec,
    const torch::Tensor fixed_transform,
    const torch::Tensor robot_spheres,
    const torch::Tensor link_map,
    const torch::Tensor joint_map,
    const torch::Tensor joint_map_type,
    const torch::Tensor store_link_map,
    const torch::Tensor link_sphere_map,
    const int batch_size,
    const int n_spheres,
    const bool use_global_cumul = false);

std::vector<torch::Tensor>kin_fused_backward_16t(
    torch::Tensor grad_out,
    const torch::Tensor grad_nlinks_pos,
    const torch::Tensor grad_nlinks_quat,
    const torch::Tensor grad_spheres,
    const torch::Tensor global_cumul_mat,
    const torch::Tensor joint_vec,
    const torch::Tensor fixed_transform,
    const torch::Tensor robot_spheres,
    const torch::Tensor link_map,
    const torch::Tensor joint_map,
    const torch::Tensor joint_map_type,
    const torch::Tensor store_link_map,
    const torch::Tensor link_sphere_map,
    const torch::Tensor link_chain_map,
    const int batch_size,
    const int n_spheres,
    const bool sparsity_opt = true,
    const bool use_global_cumul = false);

std::vector<torch::Tensor> kin_fused_backward_16t(
    torch::Tensor grad_out, const torch::Tensor grad_nlinks_pos,
    const torch::Tensor grad_nlinks_quat, const torch::Tensor grad_spheres,
    const torch::Tensor global_cumul_mat, const torch::Tensor joint_vec,
    const torch::Tensor fixed_transform, const torch::Tensor robot_spheres,
    const torch::Tensor link_map, const torch::Tensor joint_map,
    const torch::Tensor joint_map_type, const torch::Tensor store_link_map,
    const torch::Tensor link_sphere_map, const torch::Tensor link_chain_map,
    const int batch_size, const int n_spheres, const bool sparsity_opt = true,
    const bool use_global_cumul = false);
// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x); \
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), # x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    AT_ASSERTM(x.is_contiguous(), # x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x); \
    CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> kin_forward_wrapper(
    torch::Tensor link_pos, torch::Tensor link_quat,
    torch::Tensor batch_robot_spheres, torch::Tensor global_cumul_mat,
    const torch::Tensor joint_vec, const torch::Tensor fixed_transform,
    const torch::Tensor robot_spheres, const torch::Tensor link_map,
    const torch::Tensor joint_map, const torch::Tensor joint_map_type,
    const torch::Tensor store_link_map, const torch::Tensor link_sphere_map,
    const int batch_size, const int n_spheres,
    const bool use_global_cumul = false) {

std::vector<torch::Tensor>kin_forward_wrapper(
    torch::Tensor link_pos, torch::Tensor link_quat,
    torch::Tensor batch_robot_spheres, torch::Tensor global_cumul_mat,
    const torch::Tensor joint_vec, const torch::Tensor fixed_transform,
    const torch::Tensor robot_spheres, const torch::Tensor link_map,
    const torch::Tensor joint_map, const torch::Tensor joint_map_type,
    const torch::Tensor store_link_map, const torch::Tensor link_sphere_map,
    const int batch_size, const int n_spheres,
    const bool use_global_cumul = false)
{
  const at::cuda::OptionalCUDAGuard guard(joint_vec.device());

  // TODO: add check input
  return kin_fused_forward(
      link_pos, link_quat, batch_robot_spheres, global_cumul_mat, joint_vec,
      fixed_transform, robot_spheres, link_map, joint_map, joint_map_type,
      store_link_map, link_sphere_map, batch_size, n_spheres, use_global_cumul);
}

std::vector<torch::Tensor> kin_backward_wrapper(
    torch::Tensor grad_out, const torch::Tensor grad_nlinks_pos,
    const torch::Tensor grad_nlinks_quat, const torch::Tensor grad_spheres,
    const torch::Tensor global_cumul_mat, const torch::Tensor joint_vec,
    const torch::Tensor fixed_transform, const torch::Tensor robot_spheres,
    const torch::Tensor link_map, const torch::Tensor joint_map,
    const torch::Tensor joint_map_type, const torch::Tensor store_link_map,
    const torch::Tensor link_sphere_map, const torch::Tensor link_chain_map,
    const int batch_size, const int n_spheres, const bool sparsity_opt = true,
    const bool use_global_cumul = false) {
std::vector<torch::Tensor>kin_backward_wrapper(
    torch::Tensor grad_out, const torch::Tensor grad_nlinks_pos,
    const torch::Tensor grad_nlinks_quat, const torch::Tensor grad_spheres,
    const torch::Tensor global_cumul_mat, const torch::Tensor joint_vec,
    const torch::Tensor fixed_transform, const torch::Tensor robot_spheres,
    const torch::Tensor link_map, const torch::Tensor joint_map,
    const torch::Tensor joint_map_type, const torch::Tensor store_link_map,
    const torch::Tensor link_sphere_map, const torch::Tensor link_chain_map,
    const int batch_size, const int n_spheres, const bool sparsity_opt = true,
    const bool use_global_cumul = false)
{
  const at::cuda::OptionalCUDAGuard guard(joint_vec.device());

  return kin_fused_backward_16t(
      grad_out, grad_nlinks_pos, grad_nlinks_quat, grad_spheres,
      global_cumul_mat, joint_vec, fixed_transform, robot_spheres, link_map,
      joint_map, joint_map_type, store_link_map, link_sphere_map,
      link_chain_map, batch_size, n_spheres, sparsity_opt, use_global_cumul);
}

std::vector<torch::Tensor>
matrix_to_quaternion_wrapper(torch::Tensor out_quat,
    const torch::Tensor in_rot // batch_size, 3
    ) {
    )
{
  const at::cuda::OptionalCUDAGuard guard(in_rot.device());

  CHECK_INPUT(in_rot);
  CHECK_INPUT(out_quat);
  return matrix_to_quaternion(out_quat, in_rot);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("forward", &kin_forward_wrapper, "Kinematics fused forward (CUDA)");
  m.def("backward", &kin_backward_wrapper, "Kinematics fused backward (CUDA)");
  m.def("matrix_to_quaternion", &matrix_to_quaternion_wrapper,
        "Rotation Matrix to Quaternion (CUDA)");
}
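A conceptual sketch of what a fused kinematics forward pass accumulates, assuming each link's world pose is its parent's pose composed with a local (fixed plus joint) transform. This is a hypothetical CPU-side illustration with 4x4 row-major matrices; the real kernel fuses this chain with sphere placement and runs it batched on the GPU:

#include <array>
#include <vector>

using Mat4 = std::array<float, 16>;

Mat4 matmul(const Mat4& a, const Mat4& b) {
  Mat4 c{};
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      for (int k = 0; k < 4; ++k)
        c[i * 4 + j] += a[i * 4 + k] * b[k * 4 + j];
  return c;
}

// link_map[i] is the parent link index (cf. the declarations above);
// local[i] is the fixed transform composed with the joint transform.
// Assumes parents always precede children in the arrays.
std::vector<Mat4> accumulate_poses(const std::vector<int>& link_map,
                                   const std::vector<Mat4>& local) {
  std::vector<Mat4> world(local.size());
  world[0] = local[0]; // base link
  for (size_t i = 1; i < local.size(); ++i)
    world[i] = matmul(world[link_map[i]], local[i]);
  return world;
}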

File diff suppressed because it is too large
@@ -15,45 +15,68 @@
#include <c10/cuda/CUDAGuard.h>

// CUDA forward declarations
std::vector<torch::Tensor> reduce_cuda(torch::Tensor vec, torch::Tensor vec2,
    torch::Tensor rho_buffer,
    torch::Tensor sum, const int batch_size,
    const int v_dim, const int m);
std::vector<torch::Tensor>reduce_cuda(torch::Tensor vec,
    torch::Tensor vec2,
    torch::Tensor rho_buffer,
    torch::Tensor sum,
    const int batch_size,
    const int v_dim,
    const int m);

std::vector<torch::Tensor>
lbfgs_step_cuda(torch::Tensor step_vec, torch::Tensor rho_buffer,
    torch::Tensor y_buffer, torch::Tensor s_buffer,
    torch::Tensor grad_q, const float epsilon, const int batch_size,
    const int m, const int v_dim);
lbfgs_step_cuda(torch::Tensor step_vec,
    torch::Tensor rho_buffer,
    torch::Tensor y_buffer,
    torch::Tensor s_buffer,
    torch::Tensor grad_q,
    const float epsilon,
    const int batch_size,
    const int m,
    const int v_dim);

std::vector<torch::Tensor>
lbfgs_update_cuda(torch::Tensor rho_buffer, torch::Tensor y_buffer,
    torch::Tensor s_buffer, torch::Tensor q, torch::Tensor grad_q,
    torch::Tensor x_0, torch::Tensor grad_0, const int batch_size,
    const int m, const int v_dim);
lbfgs_update_cuda(torch::Tensor rho_buffer,
    torch::Tensor y_buffer,
    torch::Tensor s_buffer,
    torch::Tensor q,
    torch::Tensor grad_q,
    torch::Tensor x_0,
    torch::Tensor grad_0,
    const int batch_size,
    const int m,
    const int v_dim);

std::vector<torch::Tensor>
lbfgs_cuda_fuse(torch::Tensor step_vec, torch::Tensor rho_buffer,
    torch::Tensor y_buffer, torch::Tensor s_buffer, torch::Tensor q,
    torch::Tensor grad_q, torch::Tensor x_0, torch::Tensor grad_0,
    const float epsilon, const int batch_size, const int m,
    const int v_dim, const bool stable_mode);
lbfgs_cuda_fuse(torch::Tensor step_vec,
    torch::Tensor rho_buffer,
    torch::Tensor y_buffer,
    torch::Tensor s_buffer,
    torch::Tensor q,
    torch::Tensor grad_q,
    torch::Tensor x_0,
    torch::Tensor grad_0,
    const float epsilon,
    const int batch_size,
    const int m,
    const int v_dim,
    const bool stable_mode);

// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x); \
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), # x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    AT_ASSERTM(x.is_contiguous(), # x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x); \
    CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
lbfgs_step_call(torch::Tensor step_vec, torch::Tensor rho_buffer,
    torch::Tensor y_buffer, torch::Tensor s_buffer,
    torch::Tensor grad_q, const float epsilon, const int batch_size,
    const int m, const int v_dim) {

    const int m, const int v_dim)
{
  CHECK_INPUT(step_vec);
  CHECK_INPUT(rho_buffer);
  CHECK_INPUT(y_buffer);
@@ -69,7 +92,8 @@ std::vector<torch::Tensor>
lbfgs_update_call(torch::Tensor rho_buffer, torch::Tensor y_buffer,
    torch::Tensor s_buffer, torch::Tensor q, torch::Tensor grad_q,
    torch::Tensor x_0, torch::Tensor grad_0, const int batch_size,
    const int m, const int v_dim) {
    const int m, const int v_dim)
{
  CHECK_INPUT(rho_buffer);
  CHECK_INPUT(y_buffer);
  CHECK_INPUT(s_buffer);
@@ -86,7 +110,8 @@ lbfgs_update_call(torch::Tensor rho_buffer, torch::Tensor y_buffer,
std::vector<torch::Tensor>
reduce_cuda_call(torch::Tensor vec, torch::Tensor vec2,
    torch::Tensor rho_buffer, torch::Tensor sum,
    const int batch_size, const int v_dim, const int m) {
    const int batch_size, const int v_dim, const int m)
{
  const at::cuda::OptionalCUDAGuard guard(sum.device());

  return reduce_cuda(vec, vec2, rho_buffer, sum, batch_size, v_dim, m);
@@ -97,7 +122,8 @@ lbfgs_call(torch::Tensor step_vec, torch::Tensor rho_buffer,
    torch::Tensor y_buffer, torch::Tensor s_buffer, torch::Tensor q,
    torch::Tensor grad_q, torch::Tensor x_0, torch::Tensor grad_0,
    const float epsilon, const int batch_size, const int m,
    const int v_dim, const bool stable_mode) {
    const int v_dim, const bool stable_mode)
{
  CHECK_INPUT(step_vec);
  CHECK_INPUT(rho_buffer);
  CHECK_INPUT(y_buffer);
@@ -113,9 +139,10 @@ lbfgs_call(torch::Tensor step_vec, torch::Tensor rho_buffer,
    stable_mode);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("step", &lbfgs_step_call, "L-BFGS step (CUDA)");
  m.def("update", &lbfgs_update_call, "L-BFGS Update (CUDA)");
  m.def("forward", &lbfgs_call, "L-BFGS Update + Step (CUDA)");
  m.def("debug_reduce", &reduce_cuda_call, "L-BFGS Debug");
}
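For reference, the standard L-BFGS two-loop recursion that these fused step/update kernels implement in batched form on the GPU (Nocedal & Wright, Alg. 7.4); a plain single-problem CPU sketch, assuming m >= 1 stored curvature pairs ordered oldest to newest:

#include <numeric>
#include <vector>

using Vec = std::vector<double>;

static double dot(const Vec& a, const Vec& b) {
  return std::inner_product(a.begin(), a.end(), b.begin(), 0.0);
}

// s[k], y[k]: stored curvature pairs; rho[k] = 1 / dot(y[k], s[k]).
// Returns the quasi-Newton descent direction -H*grad.
Vec lbfgs_direction(const std::vector<Vec>& s, const std::vector<Vec>& y,
                    const std::vector<double>& rho, Vec q /* = grad */) {
  const int m = static_cast<int>(s.size());
  std::vector<double> alpha(m);
  for (int k = m - 1; k >= 0; --k) { // first loop: newest to oldest
    alpha[k] = rho[k] * dot(s[k], q);
    for (size_t i = 0; i < q.size(); ++i) q[i] -= alpha[k] * y[k][i];
  }
  // Scale by an initial Hessian estimate gamma * I.
  const double gamma = dot(s[m - 1], y[m - 1]) / dot(y[m - 1], y[m - 1]);
  for (auto& qi : q) qi *= gamma;
  for (int k = 0; k < m; ++k) {      // second loop: oldest to newest
    const double beta = rho[k] * dot(y[k], q);
    for (size_t i = 0; i < q.size(); ++i) q[i] += (alpha[k] - beta) * s[k][i];
  }
  for (auto& qi : q) qi = -qi;       // negate to get a descent direction
  return q;
}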

File diff suppressed because it is too large
@@ -18,49 +18,67 @@
// CUDA forward declarations

std::vector<torch::Tensor>
update_best_cuda(torch::Tensor best_cost, torch::Tensor best_q,
    torch::Tensor best_iteration,
    torch::Tensor current_iteration,
update_best_cuda(torch::Tensor best_cost,
    torch::Tensor best_q,
    torch::Tensor best_iteration,
    torch::Tensor current_iteration,
    const torch::Tensor cost,
    const torch::Tensor q, const int d_opt, const int cost_s1,
    const int cost_s2, const int iteration,
    const float delta_threshold,
    const float relative_threshold = 0.999);
    const torch::Tensor q,
    const int d_opt,
    const int cost_s1,
    const int cost_s2,
    const int iteration,
    const float delta_threshold,
    const float relative_threshold = 0.999);
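A hedged sketch of the per-problem bookkeeping this signature suggests: accept a new best solution only when the cost improves beyond relative and absolute thresholds. This is an assumption about the semantics for illustration, not the kernel's verified logic:

struct BestState { float best_cost; int best_iteration; };

// Hypothetical acceptance rule using the thresholds named above.
inline bool accept_new_best(BestState& s, float cost, int iteration,
                            float delta_threshold, float relative_threshold) {
  if (cost < s.best_cost * relative_threshold &&
      (s.best_cost - cost) > delta_threshold) {
    s.best_cost = cost;
    s.best_iteration = iteration;
    return true; // the caller would also copy q into best_q here
  }
  return false;
}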

std::vector<torch::Tensor>line_search_cuda(

    // torch::Tensor m,
    torch::Tensor best_x,
    torch::Tensor best_c,
    torch::Tensor best_grad,
    const torch::Tensor g_x,
    const torch::Tensor x_set,
    const torch::Tensor step_vec,
    const torch::Tensor c_0,
    const torch::Tensor alpha_list,
    const torch::Tensor c_idx,
    const float c_1,
    const float c_2,
    const bool strong_wolfe,
    const bool approx_wolfe,
    const int l1,
    const int l2,
    const int batchsize);

std::vector<torch::Tensor> line_search_cuda(
    // torch::Tensor m,
    torch::Tensor best_x, torch::Tensor best_c, torch::Tensor best_grad,
    const torch::Tensor g_x, const torch::Tensor x_set,
    const torch::Tensor step_vec, const torch::Tensor c_0,
    const torch::Tensor alpha_list, const torch::Tensor c_idx, const float c_1,
    const float c_2, const bool strong_wolfe, const bool approx_wolfe,
    const int l1, const int l2, const int batchsize);
// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x); \
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), # x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    AT_ASSERTM(x.is_contiguous(), # x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x); \
    CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> line_search_call(
    // torch::Tensor m,
    torch::Tensor best_x, torch::Tensor best_c, torch::Tensor best_grad,
    const torch::Tensor g_x, const torch::Tensor x_set,
    const torch::Tensor step_vec, const torch::Tensor c_0,
    const torch::Tensor alpha_list, const torch::Tensor c_idx, const float c_1,
    const float c_2, const bool strong_wolfe, const bool approx_wolfe,
    const int l1, const int l2, const int batchsize) {
std::vector<torch::Tensor>line_search_call(

    // torch::Tensor m,
    torch::Tensor best_x, torch::Tensor best_c, torch::Tensor best_grad,
    const torch::Tensor g_x, const torch::Tensor x_set,
    const torch::Tensor step_vec, const torch::Tensor c_0,
    const torch::Tensor alpha_list, const torch::Tensor c_idx, const float c_1,
    const float c_2, const bool strong_wolfe, const bool approx_wolfe,
    const int l1, const int l2, const int batchsize)
{
  CHECK_INPUT(g_x);
  CHECK_INPUT(x_set);
  CHECK_INPUT(step_vec);
  CHECK_INPUT(c_0);
  CHECK_INPUT(alpha_list);
  CHECK_INPUT(c_idx);

  // CHECK_INPUT(m);
  CHECK_INPUT(best_x);
  CHECK_INPUT(best_c);
@@ -76,14 +94,14 @@ std::vector<torch::Tensor> line_search_call(

std::vector<torch::Tensor>
update_best_call(torch::Tensor best_cost, torch::Tensor best_q,
    torch::Tensor best_iteration,
    torch::Tensor current_iteration,
    const torch::Tensor cost,
    const torch::Tensor q, const int d_opt, const int cost_s1,
    const int cost_s2, const int iteration,
    const float delta_threshold,
    const float relative_threshold=0.999) {

    const float relative_threshold = 0.999)
{
  CHECK_INPUT(best_cost);
  CHECK_INPUT(best_q);
  CHECK_INPUT(cost);
@@ -96,8 +114,8 @@ update_best_call(torch::Tensor best_cost, torch::Tensor best_q,
    cost_s1, cost_s2, iteration, delta_threshold, relative_threshold);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("update_best", &update_best_call, "Update Best (CUDA)");
  m.def("line_search", &line_search_call, "Line search (CUDA)");
}

@@ -37,362 +37,431 @@

#define FULL_MASK 0xffffffff

namespace Curobo {
namespace Curobo
{
namespace Optimization
{
template<typename scalar_t, typename psum_t>
__inline__ __device__ void reduce(scalar_t v, int m, unsigned mask,
                                  psum_t *data, scalar_t *result)
{
  psum_t val = v;

namespace Optimization {
  val += __shfl_down_sync(mask, val, 1);
  val += __shfl_down_sync(mask, val, 2);
  val += __shfl_down_sync(mask, val, 4);
  val += __shfl_down_sync(mask, val, 8);
  val += __shfl_down_sync(mask, val, 16);

template <typename scalar_t, typename psum_t>
__inline__ __device__ void reduce(scalar_t v, int m, unsigned mask,
                                  psum_t *data, scalar_t *result) {
  psum_t val = v;
  val += __shfl_down_sync(mask, val, 1);
  val += __shfl_down_sync(mask, val, 2);
  val += __shfl_down_sync(mask, val, 4);
  val += __shfl_down_sync(mask, val, 8);
  val += __shfl_down_sync(mask, val, 16);
  // int leader = __ffs(mask) - 1; // select a leader lane
  int leader = 0;
  if (threadIdx.x % 32 == leader) {
    if (m <= 32) {
      result[0] = (scalar_t)val;
    } else {
      data[(threadIdx.x + 1) / 32] = val;
    }
  }
  if (m > 32) {

    __syncthreads();

    int elems = (m + 31) / 32;
    assert(elems <= 32);
    unsigned mask2 = __ballot_sync(FULL_MASK, threadIdx.x < elems);
    if (threadIdx.x < elems) { // only the first warp will do this work
      psum_t val2 = data[threadIdx.x % 32];
      int shift = 1;
      for (int i = elems - 1; i > 0; i /= 2) {
        val2 += __shfl_down_sync(mask2, val2, shift);
        shift *= 2;
      }
      // int leader = __ffs(mask2) - 1; // select a leader lane
  // int leader = __ffs(mask) - 1; // select a leader lane
  int leader = 0;
      if (threadIdx.x % 32 == leader) {
        result[0] = (scalar_t)val2;

  if (threadIdx.x % 32 == leader)
  {
    if (m <= 32)
    {
      result[0] = (scalar_t)val;
    }
    else
    {
      data[(threadIdx.x + 1) / 32] = val;
    }
  }

  if (m > 32)
  {
    __syncthreads();

    int elems = (m + 31) / 32;
    assert(elems <= 32);
    unsigned mask2 = __ballot_sync(FULL_MASK, threadIdx.x < elems);

    if (threadIdx.x < elems) // only the first warp will do this work
    {
      psum_t val2 = data[threadIdx.x % 32];
      int shift = 1;

      for (int i = elems - 1; i > 0; i /= 2)
      {
        val2 += __shfl_down_sync(mask2, val2, shift);
        shift *= 2;
      }

      // int leader = __ffs(mask2) - 1; // select a leader lane
      int leader = 0;

      if (threadIdx.x % 32 == leader)
      {
        result[0] = (scalar_t)val2;
      }
    }
  }
  __syncthreads();
}
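A self-contained illustration of the warp-shuffle pattern reduce() builds on: each __shfl_down_sync halves the number of live partial sums, so five steps collapse a full 32-lane warp into lane 0. A minimal sketch, assuming a single full warp (blockDim.x == 32):

__global__ void warp_sum(const float *in, float *out)
{
  float val = in[threadIdx.x]; // one value per lane
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffff, val, offset);
  if (threadIdx.x == 0)
    *out = val; // lane 0 now holds the warp total
}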

// Launched with l2 threads/block and batchsize blocks
template<typename scalar_t, typename psum_t>
__global__ void line_search_kernel(

    // int64_t *m_idx, // 4x1x1
    scalar_t *best_x, // 4x280
    scalar_t *best_c, // 4x1
    scalar_t *best_grad, // 4x280
    const scalar_t *g_x, // 4x6x280
    const scalar_t *x_set, // 4x6x280
    const scalar_t *step_vec, // 4x280x1
    const scalar_t *c, // 4x6x1
    const scalar_t *alpha_list, // 4x6x1
    const int64_t *c_idx, // 4x1x1
    const float c_1, const float c_2, const bool strong_wolfe,
    const bool approx_wolfe,
    const int l1, // 6
    const int l2, // 280
    const int batchsize) // 4
{
  int batch = blockIdx.x;
  __shared__ psum_t data[32];
  __shared__ scalar_t result[32];

  assert(l1 <= 32);
  unsigned mask = __ballot_sync(FULL_MASK, threadIdx.x < l2);

  if (threadIdx.x >= l2)
  {
    return;
  }

  scalar_t sv_elem = step_vec[batch * l2 + threadIdx.x];

  // g_step = g0 @ step_vec_T
  // g_x @ step_vec_T
  for (int i = 0; i < l1; i++)
  {
    reduce(g_x[batch * l1 * l2 + l2 * i + threadIdx.x] * sv_elem, l2, mask,
           &data[0], &result[i]);
  }

  __shared__ scalar_t step_success[32];
  __shared__ scalar_t step_success_w1[32];
  assert(blockDim.x >= l1);
  bool wolfe_1 = false;
  bool wolfe = false;
  bool condition = threadIdx.x < l1;

  if (condition)
  {
    // scalar_t alpha_list_elem = alpha_list[batch*l1 + threadIdx.x];
    scalar_t alpha_list_elem = alpha_list[threadIdx.x];

    // condition 1:
    wolfe_1 = c[batch * l1 + threadIdx.x] <=
              (c[batch * l1] + c_1 * alpha_list_elem * result[0]);

    // condition 2:
    bool wolfe_2;

    if (strong_wolfe)
    {
      wolfe_2 = abs(result[threadIdx.x]) <= c_2 * abs(result[0]);
    }
    else
    {
      wolfe_2 = result[threadIdx.x] >= c_2 * result[0];
    }

    wolfe = wolfe_1 & wolfe_2;

    step_success[threadIdx.x] = wolfe * (alpha_list_elem + 0.1);
    step_success_w1[threadIdx.x] = wolfe_1 * (alpha_list_elem + 0.1);
  }

  __syncthreads();

  __shared__ int idx_shared;

  if (threadIdx.x == 0)
  {
    int m_id = 0;
    int m1_id = 0;
    scalar_t max1 = step_success[0];
    scalar_t max2 = step_success_w1[0];

    for (int i = 1; i < l1; i++)
    {
      if (max1 < step_success[i])
      {
        max1 = step_success[i];
        m_id = i;
      }

      if (max2 < step_success_w1[i])
      {
        max2 = step_success_w1[i];
        m1_id = i;
      }
    }

    if (!approx_wolfe)
    {
      // m_idx = torch.where(m_idx == 0, m1_idx, m_idx)
      if (m_id == 0)
      {
        m_id = m1_id;
      }

      // m_idx[m_idx == 0] = 1
      if (m_id == 0)
      {
        m_id = 1;
      }
    }
    idx_shared = m_id + c_idx[batch];
  }

  ////////////////////////////////////
  // write outputs using the computed index.
  // one index per batch is computed
  ////////////////////////////////////
  // l2 is d_opt, l1 is line_search n.
  // idx_shared contains index in l1
  //
  __syncthreads();

  if (threadIdx.x < l2)
  {
    if (threadIdx.x == 0)
    {
      // printf("block: %d, idx_shared: %d\n", batch, idx_shared);
    }
    best_x[batch * l2 + threadIdx.x] = x_set[idx_shared * l2 + threadIdx.x];
    best_grad[batch * l2 + threadIdx.x] = g_x[idx_shared * l2 + threadIdx.x];
  }

  if (threadIdx.x == 0)
  {
    best_c[batch] = c[idx_shared];
  }
}
}
}
__syncthreads();
}
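For reference, the two tests evaluated per candidate step above, restated host-side; a minimal sketch, assuming phi0/dphi0 are the objective value and directional derivative at step length 0 and phi_a/dphi_a at step length a:

#include <cmath>

// Sufficient-decrease (Armijo) condition: "wolfe_1" in the kernel above.
inline bool armijo(float phi_a, float phi0, float dphi0, float a, float c1) {
  return phi_a <= phi0 + c1 * a * dphi0;
}

// Curvature condition: "wolfe_2" above, in strong and weak forms.
inline bool curvature(float dphi_a, float dphi0, float c2, bool strong) {
  return strong ? std::fabs(dphi_a) <= c2 * std::fabs(dphi0)
                : dphi_a >= c2 * dphi0;
}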

// Launched with l2 threads/block and batchsize blocks
template <typename scalar_t, typename psum_t>
__global__ void line_search_kernel(
    // int64_t *m_idx, // 4x1x1
    scalar_t *best_x, // 4x280
    scalar_t *best_c, // 4x1
    scalar_t *best_grad, // 4x280
    const scalar_t *g_x, // 4x6x280
    const scalar_t *x_set, // 4x6x280
    const scalar_t *step_vec, // 4x280x1
    const scalar_t *c, // 4x6x1
    const scalar_t *alpha_list, // 4x6x1
    const int64_t *c_idx, // 4x1x1
    const float c_1, const float c_2, const bool strong_wolfe,
    const bool approx_wolfe,
    const int l1, // 6
    const int l2, // 280
    const int batchsize) // 4
// Launched with l2 threads/block and #blocks = batchsize
template<typename scalar_t, typename psum_t>
__global__ void line_search_kernel_mask(

    // int64_t *m_idx, // 4x1x1
    scalar_t *best_x, // 4x280
    scalar_t *best_c, // 4x1
    scalar_t *best_grad, // 4x280
    const scalar_t *g_x, // 4x6x280
    const scalar_t *x_set, // 4x6x280
    const scalar_t *step_vec, // 4x280x1
    const scalar_t *c, // 4x6x1
    const scalar_t *alpha_list, // 4x6x1
    const int64_t *c_idx, // 4x1x1
    const float c_1, const float c_2, const bool strong_wolfe,
    const bool approx_wolfe,
    const int l1, // 6
    const int l2, // 280
    const int batchsize) // 4
{
  int batch = blockIdx.x;
  __shared__ psum_t data[32];
  __shared__ scalar_t result[32];

  assert(l1 <= 32);
  unsigned mask = __ballot_sync(FULL_MASK, threadIdx.x < l2);

  if (threadIdx.x >= l2)
  {
    return;
  }

  scalar_t sv_elem = step_vec[batch * l2 + threadIdx.x];

  // g_step = g0 @ step_vec_T
  // g_x @ step_vec_T
  for (int i = 0; i < l1; i++)
  {
    reduce(g_x[batch * l1 * l2 + l2 * i + threadIdx.x] * sv_elem, l2, mask,
           &data[0], &result[i]);
  }

  // __shared__ scalar_t step_success[32];
  // __shared__ scalar_t step_success_w1[32];
  assert(blockDim.x >= l1);
  bool wolfe_1 = false;
  bool wolfe = false;
  bool condition = threadIdx.x < l1;

  if (condition)
  {
    scalar_t alpha_list_elem = alpha_list[threadIdx.x];

    // scalar_t alpha_list_elem = alpha_list[batch*l1 + threadIdx.x];

    // condition 1:
    wolfe_1 = c[batch * l1 + threadIdx.x] <=
              (c[batch * l1] + c_1 * alpha_list_elem * result[0]);

    // condition 2:
    bool wolfe_2;

    if (strong_wolfe)
    {
      wolfe_2 = abs(result[threadIdx.x]) <= c_2 * abs(result[0]);
    }
    else
    {
      wolfe_2 = result[threadIdx.x] >= c_2 * result[0];
    }

    // wolfe = torch.logical_and(wolfe_1, wolfe_2)
    wolfe = wolfe_1 & wolfe_2;

    // // step_success = wolfe * (self.alpha_list[:, :, 0:1] + 0.1)
    // // step_success_w1 = wolfe_1 * (self.alpha_list[:, :, 0:1] + 0.1)
    // step_success[threadIdx.x] = wolfe * (alpha_list_elem + 0.1);
    // step_success_w1[threadIdx.x] = wolfe_1 * (alpha_list_elem + 0.1);
  }
  unsigned msk1 = __ballot_sync(FULL_MASK, wolfe_1 & condition);
  unsigned msk = __ballot_sync(FULL_MASK, wolfe & condition);

  // get the index of the last occurrence of true
  unsigned msk1_brev = __brev(msk1);
  unsigned msk_brev = __brev(msk);

  int id1 = 32 - __ffs(msk1_brev); // position of least significant bit set to 1
  int id = 32 - __ffs(msk_brev);   // position of least significant bit set to 1

  __syncthreads();

  __shared__ int idx_shared;

  if (threadIdx.x == 0)
  {
    if (!approx_wolfe)
    {
      if (id == 32) // msk is zero
      {
        id = id1;
      }

      if (id == 0) // bit 0 is set
      {
        id = id1;
      }

      if (id == 32) // msk is zero
      {
        id = 1;
      }

      if (id == 0)
      {
        id = 1;
      }
    }
    else
    {
      if (id == 32) // msk is zero
      {
        id = 0;
      }
    }

    // // _, m_idx = torch.max(step_success, dim=-2)
    // // _, m1_idx = torch.max(step_success_w1, dim=-2)
    // int m_id = 0;
    // int m1_id = 0;
    // scalar_t max1 = step_success[0];
    // scalar_t max2 = step_success_w1[0];
    // for (int i=1; i<l1; i++) {
    //   if (max1<step_success[i]) {
    //     max1 = step_success[i];
    //     m_id = i;
    //   }
    //   if (max2<step_success_w1[i]) {
    //     max2 = step_success_w1[i];
    //     m1_id = i;
    //   }
    // }

    // // m_idx = torch.where(m_idx == 0, m1_idx, m_idx)
    // if (m_id == 0) {
    //   m_id = m1_id;
    // }

    // // m_idx[m_idx == 0] = 1
    // if (m_id == 0) {
    //   m_id = 1;
    // }

    // if (id != m_id) {
    //   printf("id=%d, m_id=%d\n", id, m_id);
    //   printf("msk1=%x, msk=%x, raw id1=%d, raw id=%d\n", msk1, msk,
    //          32-__ffs(msk1_brev), 32-__ffs(msk_brev));
    // }

    // m_idx[batch] = m_id;
    // m_idx[batch] = id;
    idx_shared = id + c_idx[batch];
  }

  ////////////////////////////////////
  // write outputs using the computed index.
  // one index per batch is computed
  ////////////////////////////////////
  __syncthreads();

  if (threadIdx.x < l2)
  {
    if (threadIdx.x == 0)
    {
      // printf("block: %d, idx_shared: %d\n", batch, idx_shared);
    }
    best_x[batch * l2 + threadIdx.x] = x_set[idx_shared * l2 + threadIdx.x];
    best_grad[batch * l2 + threadIdx.x] = g_x[idx_shared * l2 + threadIdx.x];
  }

  if (threadIdx.x == 0)
  {
    best_c[batch] = c[idx_shared];
  }
}
} // namespace Optimization
} // namespace Curobo
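The mask kernel above picks the highest set bit of a 32-bit ballot via bit-reverse plus find-first-set: __ffs(__brev(msk)) is 1 plus the reversed bit position, so 32 - __ffs(__brev(msk)) is the index of the last lane whose predicate was true, and 32 when no bit is set. A host-side sketch of the same computation for checking the logic:

#include <cstdint>

inline int last_set_bit_index(uint32_t msk) {
  if (msk == 0) return 32;             // mirrors the id == 32 "msk is zero" case
  int idx = 31;
  while (!(msk & (1u << idx))) --idx;  // highest set bit, 0-based
  return idx;                          // equals 32 - __ffs(__brev(msk)) on device
}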
|
||||
std::vector<torch::Tensor>line_search_cuda(
|
||||
|
||||
// torch::Tensor m_idx,
|
||||
torch::Tensor best_x, torch::Tensor best_c, torch::Tensor best_grad,
|
||||
const torch::Tensor g_x, const torch::Tensor x_set,
|
||||
const torch::Tensor step_vec, const torch::Tensor c_0,
|
||||
const torch::Tensor alpha_list, const torch::Tensor c_idx, const float c_1,
|
||||
const float c_2, const bool strong_wolfe, const bool approx_wolfe,
|
||||
const int l1, const int l2, const int batchsize)
|
||||
{
|
||||
|
||||
int batch = blockIdx.x;
|
||||
__shared__ psum_t data[32];
|
||||
__shared__ scalar_t result[32];
|
||||
assert(l1 <= 32);
|
||||
unsigned mask = __ballot_sync(FULL_MASK, threadIdx.x < l2);
|
||||
|
||||
if (threadIdx.x >= l2) {
|
||||
return;
|
||||
}
|
||||
|
||||
scalar_t sv_elem = step_vec[batch * l2 + threadIdx.x];
|
||||
|
||||
// g_step = g0 @ step_vec_T
|
||||
// g_x @ step_vec_T
|
||||
for (int i = 0; i < l1; i++) {
|
||||
reduce(g_x[batch * l1 * l2 + l2 * i + threadIdx.x] * sv_elem, l2, mask,
|
||||
&data[0], &result[i]);
|
||||
}
|
||||
|
||||
__shared__ scalar_t step_success[32];
|
||||
__shared__ scalar_t step_success_w1[32];
|
||||
assert(blockDim.x >= l1);
|
||||
bool wolfe_1 = false;
|
||||
bool wolfe = false;
|
||||
bool condition = threadIdx.x < l1;
|
||||
if (condition) {
|
||||
// scalar_t alpha_list_elem = alpha_list[batch*l1 + threadIdx.x];
|
||||
scalar_t alpha_list_elem = alpha_list[threadIdx.x];
|
||||
|
||||
// condition 1:
|
||||
wolfe_1 = c[batch * l1 + threadIdx.x] <=
|
||||
(c[batch * l1] + c_1 * alpha_list_elem * result[0]);
|
||||
|
||||
// condition 2:
|
||||
bool wolfe_2;
|
||||
if (strong_wolfe) {
|
||||
wolfe_2 = abs(result[threadIdx.x]) <= c_2 * abs(result[0]);
|
||||
} else {
|
||||
wolfe_2 = result[threadIdx.x] >= c_2 * result[0];
|
||||
}
|
||||
|
||||
wolfe = wolfe_1 & wolfe_2;
|
||||
|
||||
step_success[threadIdx.x] = wolfe * (alpha_list_elem + 0.1);
|
||||
step_success_w1[threadIdx.x] = wolfe_1 * (alpha_list_elem + 0.1);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
__shared__ int idx_shared;
|
||||
if (threadIdx.x == 0) {
|
||||
int m_id = 0;
|
||||
int m1_id = 0;
|
||||
scalar_t max1 = step_success[0];
|
||||
scalar_t max2 = step_success_w1[0];
|
||||
for (int i = 1; i < l1; i++) {
|
||||
if (max1 < step_success[i]) {
|
||||
max1 = step_success[i];
|
||||
m_id = i;
|
||||
}
|
||||
if (max2 < step_success_w1[i]) {
|
||||
max2 = step_success_w1[i];
|
||||
m1_id = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (!approx_wolfe) {
|
||||
|
||||
// m_idx = torch.where(m_idx == 0, m1_idx, m_idx)
|
||||
if (m_id == 0) {
|
||||
m_id = m1_id;
|
||||
}
|
||||
// m_idx[m_idx == 0] = 1
|
||||
if (m_id == 0) {
|
||||
m_id = 1;
|
||||
}
|
||||
}
|
||||
idx_shared = m_id + c_idx[batch];
|
||||
}

  ////////////////////////////////////
  // write outputs using the computed index.
  // one index per batch is computed
  ////////////////////////////////////
  // l2 is d_opt, l1 is line_search n.
  // idx_shared contains index in l1
  //
  __syncthreads();
  if (threadIdx.x < l2) {
    if (threadIdx.x == 0) {
      // printf("block: %d, idx_shared: %d\n", batch, idx_shared);
    }
    best_x[batch * l2 + threadIdx.x] = x_set[idx_shared * l2 + threadIdx.x];
    best_grad[batch * l2 + threadIdx.x] = g_x[idx_shared * l2 + threadIdx.x];
  }
  if (threadIdx.x == 0) {
    best_c[batch] = c[idx_shared];
  }
}

// Launched with l2 threads/block and #blocks = batchsize
template <typename scalar_t, typename psum_t>
__global__ void line_search_kernel_mask(
    // int64_t *m_idx,          // 4x1x1
    scalar_t *best_x,           // 4x280
    scalar_t *best_c,           // 4x1
    scalar_t *best_grad,        // 4x280
    const scalar_t *g_x,        // 4x6x280
    const scalar_t *x_set,      // 4x6x280
    const scalar_t *step_vec,   // 4x280x1
    const scalar_t *c,          // 4x6x1
    const scalar_t *alpha_list, // 4x6x1
    const int64_t *c_idx,       // 4x1x1
    const float c_1, const float c_2, const bool strong_wolfe,
    const bool approx_wolfe,
    const int l1, // 6
    const int l2, // 280
    const int batchsize) // 4
{
  int batch = blockIdx.x;
  __shared__ psum_t data[32];
  __shared__ scalar_t result[32];
  assert(l1 <= 32);
  unsigned mask = __ballot_sync(FULL_MASK, threadIdx.x < l2);

  if (threadIdx.x >= l2) {
    return;
  }

  scalar_t sv_elem = step_vec[batch * l2 + threadIdx.x];

  // g_step = g0 @ step_vec_T
  // g_x @ step_vec_T
  for (int i = 0; i < l1; i++) {
    reduce(g_x[batch * l1 * l2 + l2 * i + threadIdx.x] * sv_elem, l2, mask,
           &data[0], &result[i]);
  }

  // __shared__ scalar_t step_success[32];
  // __shared__ scalar_t step_success_w1[32];
  assert(blockDim.x >= l1);
  bool wolfe_1 = false;
  bool wolfe = false;
  bool condition = threadIdx.x < l1;
  if (condition) {
    scalar_t alpha_list_elem = alpha_list[threadIdx.x];

    // scalar_t alpha_list_elem = alpha_list[batch*l1 + threadIdx.x];

    // condition 1:
    wolfe_1 = c[batch * l1 + threadIdx.x] <=
              (c[batch * l1] + c_1 * alpha_list_elem * result[0]);

    // condition 2:
    bool wolfe_2;
    if (strong_wolfe) {
      wolfe_2 = abs(result[threadIdx.x]) <= c_2 * abs(result[0]);
    } else {
      wolfe_2 = result[threadIdx.x] >= c_2 * result[0];
    }

    // wolfe = torch.logical_and(wolfe_1, wolfe_2)
    wolfe = wolfe_1 & wolfe_2;

    // // step_success = wolfe * (self.alpha_list[:, :, 0:1] + 0.1)
    // // step_success_w1 = wolfe_1 * (self.alpha_list[:, :, 0:1] + 0.1)
    // step_success[threadIdx.x] = wolfe * (alpha_list_elem + 0.1);
    // step_success_w1[threadIdx.x] = wolfe_1 * (alpha_list_elem + 0.1);
  }
  unsigned msk1 = __ballot_sync(FULL_MASK, wolfe_1 & condition);
  unsigned msk = __ballot_sync(FULL_MASK, wolfe & condition);

  // get the index of the last occurrence of true
  unsigned msk1_brev = __brev(msk1);
  unsigned msk_brev = __brev(msk);

  int id1 = 32 - __ffs(msk1_brev); // 0-based index of the highest set bit; 32 if msk1 is empty
  int id = 32 - __ffs(msk_brev);   // 0-based index of the highest set bit; 32 if msk is empty
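
  // What the __brev/__ffs pair computes, as a plain C++ sketch: the 0-based
  // index of the last lane (largest alpha) whose condition held, or 32 when
  // no lane qualified. Hypothetical helper, for illustration only:
  //   int last_true_index(unsigned m) {
  //     for (int i = 31; i >= 0; --i)
  //       if (m & (1u << i)) return i;
  //     return 32;
  //   }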

  __syncthreads();

  __shared__ int idx_shared;
  if (threadIdx.x == 0) {
    if (!approx_wolfe) {
      if (id == 32) { // msk is zero
        id = id1;
      }
      if (id == 0) { // only the zero-step baseline qualified
        id = id1;
      }
      if (id == 32) { // msk1 is zero as well
        id = 1;
      }
      if (id == 0) {
        id = 1;
      }
    } else {
      if (id == 32) { // msk is zero
        id = 0;
      }
    }

    // // _, m_idx = torch.max(step_success, dim=-2)
    // // _, m1_idx = torch.max(step_success_w1, dim=-2)
    // int m_id = 0;
    // int m1_id = 0;
    // scalar_t max1 = step_success[0];
    // scalar_t max2 = step_success_w1[0];
    // for (int i=1; i<l1; i++) {
    //   if (max1<step_success[i]) {
    //     max1 = step_success[i];
    //     m_id = i;
    //   }
    //   if (max2<step_success_w1[i]) {
    //     max2 = step_success_w1[i];
    //     m1_id = i;
    //   }
    // }

    // // m_idx = torch.where(m_idx == 0, m1_idx, m_idx)
    // if (m_id == 0) {
    //   m_id = m1_id;
    // }

    // // m_idx[m_idx == 0] = 1
    // if (m_id == 0) {
    //   m_id = 1;
    // }

    // if (id != m_id) {
    //   printf("id=%d, m_id=%d\n", id, m_id);
    //   printf("msk1=%x, msk=%x, raw id1=%d, raw id=%d\n", msk1, msk,
    //          32-__ffs(msk1_brev), 32-__ffs(msk_brev));
    // }

    // m_idx[batch] = m_id;
    // m_idx[batch] = id;
    idx_shared = id + c_idx[batch];
  }
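
  // Restated: take the largest alpha meeting both Wolfe conditions; if none
  // (or only the zero-step baseline at index 0) qualifies, fall back to the
  // largest alpha meeting sufficient decrease alone, then to index 1 as a
  // last resort. With approx_wolfe, an empty mask falls back to index 0.
  // c_idx[batch] turns the local choice into an offset into the flattened
  // candidate set, so row idx_shared of x_set/g_x/c belongs to this batch.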

  ////////////////////////////////////
  // write outputs using the computed index.
  // one index per batch is computed
  ////////////////////////////////////
  __syncthreads();
  if (threadIdx.x < l2) {
    if (threadIdx.x == 0) {
      // printf("block: %d, idx_shared: %d\n", batch, idx_shared);
    }
    best_x[batch * l2 + threadIdx.x] = x_set[idx_shared * l2 + threadIdx.x];
    best_grad[batch * l2 + threadIdx.x] = g_x[idx_shared * l2 + threadIdx.x];
  }
  if (threadIdx.x == 0) {
    best_c[batch] = c[idx_shared];
  }
}

} // namespace Optimization
} // namespace Curobo

std::vector<torch::Tensor> line_search_cuda(
    // torch::Tensor m_idx,
    torch::Tensor best_x, torch::Tensor best_c, torch::Tensor best_grad,
    const torch::Tensor g_x, const torch::Tensor x_set,
    const torch::Tensor step_vec, const torch::Tensor c_0,
    const torch::Tensor alpha_list, const torch::Tensor c_idx, const float c_1,
    const float c_2, const bool strong_wolfe, const bool approx_wolfe,
    const int l1, const int l2, const int batchsize) {
  using namespace Curobo::Optimization;
  assert(l2 <= 1024);

  // multiple of 32
  const int threadsPerBlock = 32 * ((l2 + 31) / 32); // l2;
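  // One block per batch entry; the thread count is l2 rounded up to a full
  // 32-wide warp so the __ballot_sync masks cover every active thread
  // (e.g. l2 = 280 gives threadsPerBlock = 288).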
  const int blocksPerGrid = batchsize;

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(
      g_x.scalar_type(), "line_search_cu", ([&] {
        line_search_kernel_mask<scalar_t, scalar_t>
            <<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
                // m_idx.data_ptr<int>(),
                best_x.data_ptr<scalar_t>(), best_c.data_ptr<scalar_t>(),
                best_grad.data_ptr<scalar_t>(), g_x.data_ptr<scalar_t>(),
                x_set.data_ptr<scalar_t>(), step_vec.data_ptr<scalar_t>(),
                c_0.data_ptr<scalar_t>(), alpha_list.data_ptr<scalar_t>(),
                c_idx.data_ptr<int64_t>(), c_1, c_2, strong_wolfe, approx_wolfe,
                l1, l2, batchsize);
      }));
    g_x.scalar_type(), "line_search_cu", ([&] {
    line_search_kernel_mask<scalar_t, scalar_t>
      << < blocksPerGrid, threadsPerBlock, 0, stream >> > (
        // m_idx.data_ptr<int>(),
        best_x.data_ptr<scalar_t>(), best_c.data_ptr<scalar_t>(),
        best_grad.data_ptr<scalar_t>(), g_x.data_ptr<scalar_t>(),
        x_set.data_ptr<scalar_t>(), step_vec.data_ptr<scalar_t>(),
        c_0.data_ptr<scalar_t>(), alpha_list.data_ptr<scalar_t>(),
        c_idx.data_ptr<int64_t>(), c_1, c_2, strong_wolfe, approx_wolfe,
        l1, l2, batchsize);
  }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return {best_x, best_c, best_grad};
}
  return { best_x, best_c, best_grad };
}
@@ -13,72 +13,120 @@

#include <c10/cuda/CUDAGuard.h>
#include <vector>

std::vector<torch::Tensor> step_position_clique(
    torch::Tensor out_position, torch::Tensor out_velocity,
    torch::Tensor out_acceleration, torch::Tensor out_jerk,
    const torch::Tensor u_position, const torch::Tensor start_position,
    const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
    const torch::Tensor traj_dt, const int batch_size, const int horizon,
    const int dof);
std::vector<torch::Tensor> step_position_clique2(
    torch::Tensor out_position, torch::Tensor out_velocity,
    torch::Tensor out_acceleration, torch::Tensor out_jerk,
    const torch::Tensor u_position, const torch::Tensor start_position,
    const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
    const torch::Tensor traj_dt, const int batch_size, const int horizon,
    const int dof, const int mode);
std::vector<torch::Tensor> step_position_clique2_idx(
    torch::Tensor out_position, torch::Tensor out_velocity,
    torch::Tensor out_acceleration, torch::Tensor out_jerk,
    const torch::Tensor u_position, const torch::Tensor start_position,
    const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
    const torch::Tensor start_idx, const torch::Tensor traj_dt,
    const int batch_size, const int horizon, const int dof, const int mode);
std::vector<torch::Tensor>step_position_clique(
  torch::Tensor out_position,
  torch::Tensor out_velocity,
  torch::Tensor out_acceleration,
  torch::Tensor out_jerk,
  const torch::Tensor u_position,
  const torch::Tensor start_position,
  const torch::Tensor start_velocity,
  const torch::Tensor start_acceleration,
  const torch::Tensor traj_dt,
  const int batch_size,
  const int horizon,
  const int dof);
std::vector<torch::Tensor>step_position_clique2(
  torch::Tensor out_position,
  torch::Tensor out_velocity,
  torch::Tensor out_acceleration,
  torch::Tensor out_jerk,
  const torch::Tensor u_position,
  const torch::Tensor start_position,
  const torch::Tensor start_velocity,
  const torch::Tensor start_acceleration,
  const torch::Tensor traj_dt,
  const int batch_size,
  const int horizon,
  const int dof,
  const int mode);
std::vector<torch::Tensor>step_position_clique2_idx(
  torch::Tensor out_position,
  torch::Tensor out_velocity,
  torch::Tensor out_acceleration,
  torch::Tensor out_jerk,
  const torch::Tensor u_position,
  const torch::Tensor start_position,
  const torch::Tensor start_velocity,
  const torch::Tensor start_acceleration,
  const torch::Tensor start_idx,
  const torch::Tensor traj_dt,
  const int batch_size,
  const int horizon,
  const int dof,
  const int mode);

std::vector<torch::Tensor> backward_step_position_clique(
    torch::Tensor out_grad_position, const torch::Tensor grad_position,
    const torch::Tensor grad_velocity, const torch::Tensor grad_acceleration,
    const torch::Tensor grad_jerk, const torch::Tensor traj_dt,
    const int batch_size, const int horizon, const int dof);
std::vector<torch::Tensor> backward_step_position_clique2(
    torch::Tensor out_grad_position, const torch::Tensor grad_position,
    const torch::Tensor grad_velocity, const torch::Tensor grad_acceleration,
    const torch::Tensor grad_jerk, const torch::Tensor traj_dt,
    const int batch_size, const int horizon, const int dof, const int mode);
std::vector<torch::Tensor>backward_step_position_clique(
  torch::Tensor out_grad_position,
  const torch::Tensor grad_position,
  const torch::Tensor grad_velocity,
  const torch::Tensor grad_acceleration,
  const torch::Tensor grad_jerk,
  const torch::Tensor traj_dt,
  const int batch_size,
  const int horizon,
  const int dof);
std::vector<torch::Tensor>backward_step_position_clique2(
  torch::Tensor out_grad_position,
  const torch::Tensor grad_position,
  const torch::Tensor grad_velocity,
  const torch::Tensor grad_acceleration,
  const torch::Tensor grad_jerk,
  const torch::Tensor traj_dt,
  const int batch_size,
  const int horizon,
  const int dof,
  const int mode);

std::vector<torch::Tensor>
step_acceleration(torch::Tensor out_position, torch::Tensor out_velocity,
                  torch::Tensor out_acceleration, torch::Tensor out_jerk,
                  const torch::Tensor u_acc, const torch::Tensor start_position,
                  const torch::Tensor start_velocity,
                  const torch::Tensor start_acceleration,
                  const torch::Tensor traj_dt, const int batch_size,
                  const int horizon, const int dof, const bool use_rk2 = true);
std::vector<torch::Tensor>
step_acceleration(torch::Tensor out_position,
                  torch::Tensor out_velocity,
                  torch::Tensor out_acceleration,
                  torch::Tensor out_jerk,
                  const torch::Tensor u_acc,
                  const torch::Tensor start_position,
                  const torch::Tensor start_velocity,
                  const torch::Tensor start_acceleration,
                  const torch::Tensor traj_dt,
                  const int batch_size,
                  const int horizon,
                  const int dof,
                  const bool use_rk2 = true);

std::vector<torch::Tensor>step_acceleration_idx(
  torch::Tensor out_position,
  torch::Tensor out_velocity,
  torch::Tensor out_acceleration,
  torch::Tensor out_jerk,
  const torch::Tensor u_acc,
  const torch::Tensor start_position,
  const torch::Tensor start_velocity,
  const torch::Tensor start_acceleration,
  const torch::Tensor start_idx,
  const torch::Tensor traj_dt,
  const int batch_size,
  const int horizon,
  const int dof,
  const bool use_rk2 = true);

std::vector<torch::Tensor> step_acceleration_idx(
    torch::Tensor out_position, torch::Tensor out_velocity,
    torch::Tensor out_acceleration, torch::Tensor out_jerk,
    const torch::Tensor u_acc, const torch::Tensor start_position,
    const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
    const torch::Tensor start_idx, const torch::Tensor traj_dt,
    const int batch_size, const int horizon, const int dof,
    const bool use_rk2 = true);
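
// use_rk2 selects how commanded accelerations are integrated into
// positions/velocities. As a rough sketch of the distinction (an assumption
// for illustration; the actual scheme lives in the CUDA kernels):
//   euler: p += v * dt;                   v += a * dt;
//   rk2:   p += (v + 0.5f * a * dt) * dt; v += a * dt;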

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x); \
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), # x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), # x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x); \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> step_position_clique_wrapper(
    torch::Tensor out_position, torch::Tensor out_velocity,
    torch::Tensor out_acceleration, torch::Tensor out_jerk,
    const torch::Tensor u_position, const torch::Tensor start_position,
    const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
    const torch::Tensor traj_dt, const int batch_size, const int horizon,
    const int dof) {
std::vector<torch::Tensor>step_position_clique_wrapper(
  torch::Tensor out_position, torch::Tensor out_velocity,
  torch::Tensor out_acceleration, torch::Tensor out_jerk,
  const torch::Tensor u_position, const torch::Tensor start_position,
  const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
  const torch::Tensor traj_dt, const int batch_size, const int horizon,
  const int dof)
{
  const at::cuda::OptionalCUDAGuard guard(u_position.device());

  assert(false); // not supported
  CHECK_INPUT(u_position);
  CHECK_INPUT(out_position);
@@ -96,14 +144,15 @@ std::vector<torch::Tensor> step_position_clique_wrapper(
      batch_size, horizon, dof);
}

std::vector<torch::Tensor> step_position_clique2_wrapper(
    torch::Tensor out_position, torch::Tensor out_velocity,
    torch::Tensor out_acceleration, torch::Tensor out_jerk,
    const torch::Tensor u_position, const torch::Tensor start_position,
    const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
    const torch::Tensor traj_dt, const int batch_size, const int horizon,
    const int dof,
    const int mode) {
std::vector<torch::Tensor>step_position_clique2_wrapper(
  torch::Tensor out_position, torch::Tensor out_velocity,
  torch::Tensor out_acceleration, torch::Tensor out_jerk,
  const torch::Tensor u_position, const torch::Tensor start_position,
  const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
  const torch::Tensor traj_dt, const int batch_size, const int horizon,
  const int dof,
  const int mode)
{
  const at::cuda::OptionalCUDAGuard guard(u_position.device());

  CHECK_INPUT(u_position);
@@ -122,14 +171,15 @@ std::vector<torch::Tensor> step_position_clique2_wrapper(
      batch_size, horizon, dof, mode);
}

std::vector<torch::Tensor> step_position_clique2_idx_wrapper(
    torch::Tensor out_position, torch::Tensor out_velocity,
    torch::Tensor out_acceleration, torch::Tensor out_jerk,
    const torch::Tensor u_position, const torch::Tensor start_position,
    const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
    const torch::Tensor start_idx, const torch::Tensor traj_dt,
    const int batch_size, const int horizon, const int dof,
    const int mode) {
std::vector<torch::Tensor>step_position_clique2_idx_wrapper(
  torch::Tensor out_position, torch::Tensor out_velocity,
  torch::Tensor out_acceleration, torch::Tensor out_jerk,
  const torch::Tensor u_position, const torch::Tensor start_position,
  const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
  const torch::Tensor start_idx, const torch::Tensor traj_dt,
  const int batch_size, const int horizon, const int dof,
  const int mode)
{
  const at::cuda::OptionalCUDAGuard guard(u_position.device());

  CHECK_INPUT(u_position);
@@ -144,17 +194,19 @@ std::vector<torch::Tensor> step_position_clique2_idx_wrapper(
  CHECK_INPUT(start_idx);

  return step_position_clique2_idx(
      out_position, out_velocity, out_acceleration, out_jerk, u_position,
      start_position, start_velocity, start_acceleration, start_idx, traj_dt,
      batch_size, horizon, dof, mode);
    out_position, out_velocity, out_acceleration, out_jerk, u_position,
    start_position, start_velocity, start_acceleration, start_idx, traj_dt,
    batch_size, horizon, dof, mode);
}

std::vector<torch::Tensor> backward_step_position_clique_wrapper(
    torch::Tensor out_grad_position, const torch::Tensor grad_position,
    const torch::Tensor grad_velocity, const torch::Tensor grad_acceleration,
    const torch::Tensor grad_jerk, const torch::Tensor traj_dt,
    const int batch_size, const int horizon, const int dof) {
std::vector<torch::Tensor>backward_step_position_clique_wrapper(
  torch::Tensor out_grad_position, const torch::Tensor grad_position,
  const torch::Tensor grad_velocity, const torch::Tensor grad_acceleration,
  const torch::Tensor grad_jerk, const torch::Tensor traj_dt,
  const int batch_size, const int horizon, const int dof)
{
  const at::cuda::OptionalCUDAGuard guard(grad_position.device());

  assert(false); // not supported
  CHECK_INPUT(out_grad_position);
  CHECK_INPUT(grad_position);
@@ -164,15 +216,17 @@ std::vector<torch::Tensor> backward_step_position_clique_wrapper(
  CHECK_INPUT(traj_dt);

  return backward_step_position_clique(
      out_grad_position, grad_position, grad_velocity, grad_acceleration,
      grad_jerk, traj_dt, batch_size, horizon, dof);
    out_grad_position, grad_position, grad_velocity, grad_acceleration,
    grad_jerk, traj_dt, batch_size, horizon, dof);
}
std::vector<torch::Tensor> backward_step_position_clique2_wrapper(
    torch::Tensor out_grad_position, const torch::Tensor grad_position,
    const torch::Tensor grad_velocity, const torch::Tensor grad_acceleration,
    const torch::Tensor grad_jerk, const torch::Tensor traj_dt,
    const int batch_size, const int horizon, const int dof,
    const int mode) {

std::vector<torch::Tensor>backward_step_position_clique2_wrapper(
  torch::Tensor out_grad_position, const torch::Tensor grad_position,
  const torch::Tensor grad_velocity, const torch::Tensor grad_acceleration,
  const torch::Tensor grad_jerk, const torch::Tensor traj_dt,
  const int batch_size, const int horizon, const int dof,
  const int mode)
{
  const at::cuda::OptionalCUDAGuard guard(grad_position.device());

  CHECK_INPUT(out_grad_position);
@@ -183,17 +237,18 @@ std::vector<torch::Tensor> backward_step_position_clique2_wrapper(
  CHECK_INPUT(traj_dt);

  return backward_step_position_clique2(
      out_grad_position, grad_position, grad_velocity, grad_acceleration,
      grad_jerk, traj_dt, batch_size, horizon, dof, mode);
    out_grad_position, grad_position, grad_velocity, grad_acceleration,
    grad_jerk, traj_dt, batch_size, horizon, dof, mode);
}

std::vector<torch::Tensor> step_acceleration_wrapper(
    torch::Tensor out_position, torch::Tensor out_velocity,
    torch::Tensor out_acceleration, torch::Tensor out_jerk,
    const torch::Tensor u_acc, const torch::Tensor start_position,
    const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
    const torch::Tensor traj_dt, const int batch_size, const int horizon,
    const int dof, const bool use_rk2 = true) {
std::vector<torch::Tensor>step_acceleration_wrapper(
  torch::Tensor out_position, torch::Tensor out_velocity,
  torch::Tensor out_acceleration, torch::Tensor out_jerk,
  const torch::Tensor u_acc, const torch::Tensor start_position,
  const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
  const torch::Tensor traj_dt, const int batch_size, const int horizon,
  const int dof, const bool use_rk2 = true)
{
  const at::cuda::OptionalCUDAGuard guard(u_acc.device());

  CHECK_INPUT(u_acc);
@@ -212,14 +267,15 @@ std::vector<torch::Tensor> step_acceleration_wrapper(
      dof, use_rk2);
}

std::vector<torch::Tensor> step_acceleration_idx_wrapper(
    torch::Tensor out_position, torch::Tensor out_velocity,
    torch::Tensor out_acceleration, torch::Tensor out_jerk,
    const torch::Tensor u_acc, const torch::Tensor start_position,
    const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
    const torch::Tensor start_idx, const torch::Tensor traj_dt,
    const int batch_size, const int horizon, const int dof,
    const bool use_rk2 = true) {
std::vector<torch::Tensor>step_acceleration_idx_wrapper(
  torch::Tensor out_position, torch::Tensor out_velocity,
  torch::Tensor out_acceleration, torch::Tensor out_jerk,
  const torch::Tensor u_acc, const torch::Tensor start_position,
  const torch::Tensor start_velocity, const torch::Tensor start_acceleration,
  const torch::Tensor start_idx, const torch::Tensor traj_dt,
  const int batch_size, const int horizon, const int dof,
  const bool use_rk2 = true)
{
  const at::cuda::OptionalCUDAGuard guard(u_acc.device());

  CHECK_INPUT(u_acc);
@@ -239,19 +295,20 @@ std::vector<torch::Tensor> step_acceleration_idx_wrapper(
      batch_size, horizon, dof, use_rk2);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("step_position", &step_position_clique_wrapper,
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("step_position", &step_position_clique_wrapper,
        "Tensor Step Position (curobolib)");
  m.def("step_position2", &step_position_clique2_wrapper,
  m.def("step_position2", &step_position_clique2_wrapper,
        "Tensor Step Position (curobolib)");
  m.def("step_idx_position2", &step_position_clique2_idx_wrapper,
  m.def("step_idx_position2", &step_position_clique2_idx_wrapper,
        "Tensor Step Position (curobolib)");
  m.def("step_position_backward", &backward_step_position_clique_wrapper,
  m.def("step_position_backward", &backward_step_position_clique_wrapper,
        "Tensor Step Position (curobolib)");
  m.def("step_position_backward2", &backward_step_position_clique2_wrapper,
        "Tensor Step Position (curobolib)");
  m.def("step_acceleration", &step_acceleration_wrapper,
  m.def("step_acceleration", &step_acceleration_wrapper,
        "Tensor Step Acceleration (curobolib)");
  m.def("step_acceleration_idx", &step_acceleration_idx_wrapper,
  m.def("step_acceleration_idx", &step_acceleration_idx_wrapper,
        "Tensor Step Acceleration (curobolib)");
}
}
@@ -26,101 +26,108 @@

#include <cub/cub.cuh>
#include <math.h>

namespace Curobo {

namespace Optimization {

// We launch with d_opt*cost_s1 threads.
// We assume that cost_s2 is always 1.
template <typename scalar_t>
__global__ void update_best_kernel(scalar_t *best_cost,        // 200x1
                                   scalar_t *best_q,           // 200x7
                                   int16_t *best_iteration,    // 200 x 1
                                   int16_t *current_iteration, // 1
                                   const scalar_t *cost,       // 200x1
                                   const scalar_t *q,          // 200x7
                                   const int d_opt,            // 7
                                   const int cost_s1,          // 200
                                   const int cost_s2,
                                   const int iteration,
                                   const float delta_threshold,
                                   const float relative_threshold) // 1
{
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int size = cost_s1 * d_opt; // size of best_q
  if (tid >= size) {
    return;
  }

  const int cost_idx = tid / d_opt;
  const float cost_new = cost[cost_idx];
  const float best_cost_in = best_cost[cost_idx];
  const bool change = (best_cost_in - cost_new) > delta_threshold && cost_new < best_cost_in * relative_threshold;
  if (change) {
    best_q[tid] = q[tid]; // update best_q

    if (tid % d_opt == 0) {
      best_cost[cost_idx] = cost_new; // update best_cost
      // best_iteration[cost_idx] = curr_iter + iteration; //
      // this tensor keeps track of whether the cost reduced by at least
      // delta_threshold.
      // here iteration is the last_best parameter.
    }
  }

  if (tid % d_opt == 0) {
    if (change) {
      best_iteration[cost_idx] = 0;
    } else {
      best_iteration[cost_idx] -= 1;
    }
  }

  //.if (tid == 0)
  //{
  //  curr_iter += 1;
  //  current_iteration[0] = curr_iter;
  //}
}
} // namespace Optimization
} // namespace Curobo

namespace Curobo
{
namespace Optimization
{
// We launch with d_opt*cost_s1 threads.
// We assume that cost_s2 is always 1.
template<typename scalar_t>
__global__ void update_best_kernel(scalar_t *best_cost,        // 200x1
                                   scalar_t *best_q,           // 200x7
                                   int16_t *best_iteration,    // 200 x 1
                                   int16_t *current_iteration, // 1
                                   const scalar_t *cost,       // 200x1
                                   const scalar_t *q,          // 200x7
                                   const int d_opt,            // 7
                                   const int cost_s1,          // 200
                                   const int cost_s2,
                                   const int iteration,
                                   const float delta_threshold,
                                   const float relative_threshold) // 1
{
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int size = cost_s1 * d_opt; // size of best_q

  if (tid >= size)
  {
    return;
  }

  const int cost_idx = tid / d_opt;
  const float cost_new = cost[cost_idx];
  const float best_cost_in = best_cost[cost_idx];
  const bool change = (best_cost_in - cost_new) > delta_threshold &&
                      cost_new < best_cost_in * relative_threshold;

  if (change)
  {
    best_q[tid] = q[tid]; // update best_q

    if (tid % d_opt == 0)
    {
      best_cost[cost_idx] = cost_new; // update best_cost
      // best_iteration[cost_idx] = curr_iter + iteration; //
      // this tensor keeps track of whether the cost reduced by at least
      // delta_threshold.
      // here iteration is the last_best parameter.
    }
  }

  if (tid % d_opt == 0)
  {
    if (change)
    {
      best_iteration[cost_idx] = 0;
    }
    else
    {
      best_iteration[cost_idx] -= 1;
    }
  }
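
  // Torch reference for the update rule above (a sketch; shapes follow the
  // signature comments):
  //   change = (best_cost - cost > delta_threshold)
  //            & (cost < best_cost * relative_threshold)
  //   best_q         = torch.where(change, q, best_q)
  //   best_cost      = torch.where(change, cost, best_cost)
  //   best_iteration = torch.where(change, 0, best_iteration - 1)
  // i.e. the per-problem counter resets on a real improvement and decays
  // while the optimizer stagnates.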

  // .if (tid == 0)
  // {
  //   curr_iter += 1;
  //   current_iteration[0] = curr_iter;
  // }
}
} // namespace Optimization
} // namespace Curobo

std::vector<torch::Tensor>
update_best_cuda(torch::Tensor best_cost, torch::Tensor best_q,
                 torch::Tensor best_iteration,
                 torch::Tensor current_iteration,
                 const torch::Tensor cost,
                 const torch::Tensor q, const int d_opt, const int cost_s1,
                 const int cost_s2, const int iteration,
                 const float delta_threshold,
                 const float relative_threshold = 0.999) {
                 const float relative_threshold = 0.999)
{
  using namespace Curobo::Optimization;
  const int threadsPerBlock = 128;
  const int cost_size = cost_s1 * d_opt;
  assert(cost_s2 == 1); // assumption
  const int blocksPerGrid = (cost_size + threadsPerBlock - 1) / threadsPerBlock;

  // printf("cost_s1=%d, d_opt=%d, blocksPerGrid=%d\n", cost_s1, d_opt,
  // blocksPerGrid);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(
      cost.scalar_type(), "update_best_cu", ([&] {
        update_best_kernel<scalar_t>
            <<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
                best_cost.data_ptr<scalar_t>(), best_q.data_ptr<scalar_t>(),
                best_iteration.data_ptr<int16_t>(),
                current_iteration.data_ptr<int16_t>(),
                cost.data_ptr<scalar_t>(),
                q.data_ptr<scalar_t>(), d_opt, cost_s1, cost_s2, iteration,
                delta_threshold, relative_threshold);
      }));
    cost.scalar_type(), "update_best_cu", ([&] {
    update_best_kernel<scalar_t>
      << < blocksPerGrid, threadsPerBlock, 0, stream >> > (
      best_cost.data_ptr<scalar_t>(), best_q.data_ptr<scalar_t>(),
      best_iteration.data_ptr<int16_t>(),
      current_iteration.data_ptr<int16_t>(),
      cost.data_ptr<scalar_t>(),
      q.data_ptr<scalar_t>(), d_opt, cost_s1, cost_s2, iteration,
      delta_threshold, relative_threshold);
  }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return {best_cost, best_q, best_iteration};
}
  return { best_cost, best_q, best_iteration };
}

@@ -153,6 +153,8 @@ def get_pose_distance(
    vec_convergence,
    run_weight,
    run_vec_weight,
    offset_waypoint,
    offset_tstep_fraction,
    batch_pose_idx,
    batch_size,
    horizon,
@@ -161,6 +163,7 @@ def get_pose_distance(
    write_grad=False,
    write_distance=False,
    use_metric=False,
    project_distance=True,
):
    if batch_pose_idx.shape[0] != batch_size:
        raise ValueError("Index buffer size is different from batch size")
@@ -181,6 +184,8 @@ def get_pose_distance(
        vec_convergence,
        run_weight,
        run_vec_weight,
        offset_waypoint,
        offset_tstep_fraction,
        batch_pose_idx,
        batch_size,
        horizon,
@@ -189,6 +194,7 @@ def get_pose_distance(
        write_grad,
        write_distance,
        use_metric,
        project_distance,
    )

    out_distance = r[0]
@@ -229,6 +235,331 @@ def get_pose_distance_backward(
    return r[0], r[1]


@torch.jit.script
def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
    grad_vec = grad_g_dist + (grad_out_distance * weight)
    grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec
    return grad


# full method:
@torch.jit.script
def backward_full_PoseError_jit(
    grad_out_distance, grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
):
    p_grad = (grad_g_dist + (grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
    q_grad = (grad_r_err + (grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
    # p_grad = ((grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
    # q_grad = ((grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q

    return p_grad, q_grad
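

# Example of the gradient assembly these helpers perform (hypothetical
# shapes for illustration, not part of the library):
#   grad_g_dist, grad_out_distance: [b]; g_vec_p: [b, 3]
#   pos_grad = backward_PoseError_jit(grad_g_dist, grad_out_distance,
#                                     weight[1], g_vec_p)  # -> [b, 3]
# weight[1] is the position weight and weight[0] the rotation weight,
# matching the indexing used in PoseErrorDistance.backward below.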


class PoseErrorDistance(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        current_position,
        goal_position,
        current_quat,
        goal_quat,
        vec_weight,
        weight,
        vec_convergence,
        run_weight,
        run_vec_weight,
        offset_waypoint,
        offset_tstep_fraction,
        batch_pose_idx,
        out_distance,
        out_position_distance,
        out_rotation_distance,
        out_p_vec,
        out_r_vec,
        out_idx,
        out_p_grad,
        out_q_grad,
        batch_size,
        horizon,
        mode,  # =PoseErrorType.BATCH_GOAL.value,
        num_goals,
        use_metric,  # =False,
        project_distance,  # =True,
    ):
        # out_distance = current_position[..., 0].detach().clone() * 0.0
        # out_position_distance = out_distance.detach().clone()
        # out_rotation_distance = out_distance.detach().clone()
        # out_vec = (
        #     torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
        #     * 0.0
        # )
        # out_idx = out_distance.clone().to(dtype=torch.long)

        (
            out_distance,
            out_position_distance,
            out_rotation_distance,
            out_p_vec,
            out_r_vec,
            out_idx,
        ) = get_pose_distance(
            out_distance,
            out_position_distance,
            out_rotation_distance,
            out_p_vec,
            out_r_vec,
            out_idx,
            current_position.contiguous(),
            goal_position,
            current_quat.contiguous(),
            goal_quat,
            vec_weight,
            weight,
            vec_convergence,
            run_weight,
            run_vec_weight,
            offset_waypoint,
            offset_tstep_fraction,
            batch_pose_idx,
            batch_size,
            horizon,
            mode,
            num_goals,
            current_position.requires_grad,
            True,
            use_metric,
            project_distance,
        )
        ctx.save_for_backward(out_p_vec, out_r_vec, weight, out_p_grad, out_q_grad)
        return out_distance, out_position_distance, out_rotation_distance, out_idx  # .view(-1,1)

    @staticmethod
    def backward(ctx, grad_out_distance, grad_g_dist, grad_r_err, grad_out_idx):
        (g_vec_p, g_vec_q, weight, out_grad_p, out_grad_q) = ctx.saved_tensors
        pos_grad = None
        quat_grad = None
        batch_size = g_vec_p.shape[0] * g_vec_p.shape[1]
        if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
            pos_grad, quat_grad = get_pose_distance_backward(
                out_grad_p,
                out_grad_q,
                grad_out_distance.contiguous(),
                grad_g_dist.contiguous(),
                grad_r_err.contiguous(),
                weight,
                g_vec_p,
                g_vec_q,
                batch_size,
                use_distance=True,
            )

        elif ctx.needs_input_grad[0]:
            pos_grad = backward_PoseError_jit(grad_g_dist, grad_out_distance, weight[1], g_vec_p)

        elif ctx.needs_input_grad[2]:
            quat_grad = backward_PoseError_jit(grad_r_err, grad_out_distance, weight[0], g_vec_q)

        return (
            pos_grad,
            None,
            quat_grad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class PoseError(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        current_position: torch.Tensor,
        goal_position: torch.Tensor,
        current_quat: torch.Tensor,
        goal_quat,
        vec_weight,
        weight,
        vec_convergence,
        run_weight,
        run_vec_weight,
        offset_waypoint,
        offset_tstep_fraction,
        batch_pose_idx,
        out_distance,
        out_position_distance,
        out_rotation_distance,
        out_p_vec,
        out_r_vec,
        out_idx,
        out_p_grad,
        out_q_grad,
        batch_size,
        horizon,
        mode,
        num_goals,
        use_metric,
        project_distance,
        return_loss,
    ):
        """Compute error in pose

        _extended_summary_

        Args:
            ctx: _description_
            current_position: _description_
            goal_position: _description_
            current_quat: _description_
            goal_quat: _description_
            vec_weight: _description_
            weight: _description_
            vec_convergence: _description_
            run_weight: _description_
            run_vec_weight: _description_
            offset_waypoint: _description_
            offset_tstep_fraction: _description_
            batch_pose_idx: _description_
            out_distance: _description_
            out_position_distance: _description_
            out_rotation_distance: _description_
            out_p_vec: _description_
            out_r_vec: _description_
            out_idx: _description_
            out_p_grad: _description_
            out_q_grad: _description_
            batch_size: _description_
            horizon: _description_
            mode: _description_
            num_goals: _description_
            use_metric: _description_
            project_distance: _description_
            return_loss: _description_

        Returns:
            _description_
        """
        # out_distance = current_position[..., 0].detach().clone() * 0.0
        # out_position_distance = out_distance.detach().clone()
        # out_rotation_distance = out_distance.detach().clone()
        # out_vec = (
        #     torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
        #     * 0.0
        # )
        # out_idx = out_distance.clone().to(dtype=torch.long)
        ctx.return_loss = return_loss
        (
            out_distance,
            out_position_distance,
            out_rotation_distance,
            out_p_vec,
            out_r_vec,
            out_idx,
        ) = get_pose_distance(
            out_distance,
            out_position_distance,
            out_rotation_distance,
            out_p_vec,
            out_r_vec,
            out_idx,
            current_position.contiguous(),
            goal_position,
            current_quat.contiguous(),
            goal_quat,
            vec_weight,
            weight,
            vec_convergence,
            run_weight,
            run_vec_weight,
            offset_waypoint,
            offset_tstep_fraction,
            batch_pose_idx,
            batch_size,
            horizon,
            mode,
            num_goals,
            current_position.requires_grad,
            False,
            use_metric,
            project_distance,
        )
        ctx.save_for_backward(out_p_vec, out_r_vec)
        return out_distance

    @staticmethod
    def backward(ctx, grad_out_distance):  # , grad_g_dist, grad_r_err, grad_out_idx):
        pos_grad = None
        quat_grad = None
        if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
            (g_vec_p, g_vec_q) = ctx.saved_tensors
            pos_grad = g_vec_p
            quat_grad = g_vec_q
            if ctx.return_loss:
                pos_grad = pos_grad * grad_out_distance.unsqueeze(1)
                quat_grad = quat_grad * grad_out_distance.unsqueeze(1)

        elif ctx.needs_input_grad[0]:
            (g_vec_p, g_vec_q) = ctx.saved_tensors

            pos_grad = g_vec_p
            if ctx.return_loss:
                pos_grad = pos_grad * grad_out_distance.unsqueeze(1)
        elif ctx.needs_input_grad[2]:
            (g_vec_p, g_vec_q) = ctx.saved_tensors

            quat_grad = g_vec_q
            if ctx.return_loss:
                quat_grad = quat_grad * grad_out_distance.unsqueeze(1)
        return (
            pos_grad,
            None,
            quat_grad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
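

# backward() must return one gradient per forward() input, in the same
# order; non-tensor and non-differentiable inputs get None, which is why
# the tuples above are mostly None. A minimal illustration of the pattern
# (hypothetical example, not part of curobo):
#
#   class Scale(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, x, factor, out_buffer):
#           ctx.factor = factor
#           return x * factor
#
#       @staticmethod
#       def backward(ctx, grad_out):
#           return grad_out * ctx.factor, None, None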


class SdfSphereOBB(torch.autograd.Function):
    @staticmethod
    def forward(