constrained planning, robot segmentation
This commit is contained in:
@@ -16,91 +16,112 @@
|
||||
// CUDA forward declarations
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
matrix_to_quaternion(torch::Tensor out_quat,
|
||||
matrix_to_quaternion(torch::Tensor out_quat,
|
||||
const torch::Tensor in_rot // batch_size, 3
|
||||
);
|
||||
);
|
||||
|
||||
std::vector<torch::Tensor> kin_fused_forward(
|
||||
torch::Tensor link_pos, torch::Tensor link_quat,
|
||||
torch::Tensor batch_robot_spheres, torch::Tensor global_cumul_mat,
|
||||
const torch::Tensor joint_vec, const torch::Tensor fixed_transform,
|
||||
const torch::Tensor robot_spheres, const torch::Tensor link_map,
|
||||
const torch::Tensor joint_map, const torch::Tensor joint_map_type,
|
||||
const torch::Tensor store_link_map, const torch::Tensor link_sphere_map,
|
||||
const int batch_size, const int n_spheres,
|
||||
const bool use_global_cumul = false);
|
||||
std::vector<torch::Tensor>kin_fused_forward(
|
||||
torch::Tensor link_pos,
|
||||
torch::Tensor link_quat,
|
||||
torch::Tensor batch_robot_spheres,
|
||||
torch::Tensor global_cumul_mat,
|
||||
const torch::Tensor joint_vec,
|
||||
const torch::Tensor fixed_transform,
|
||||
const torch::Tensor robot_spheres,
|
||||
const torch::Tensor link_map,
|
||||
const torch::Tensor joint_map,
|
||||
const torch::Tensor joint_map_type,
|
||||
const torch::Tensor store_link_map,
|
||||
const torch::Tensor link_sphere_map,
|
||||
const int batch_size,
|
||||
const int n_spheres,
|
||||
const bool use_global_cumul = false);
|
||||
|
||||
std::vector<torch::Tensor>kin_fused_backward_16t(
|
||||
torch::Tensor grad_out,
|
||||
const torch::Tensor grad_nlinks_pos,
|
||||
const torch::Tensor grad_nlinks_quat,
|
||||
const torch::Tensor grad_spheres,
|
||||
const torch::Tensor global_cumul_mat,
|
||||
const torch::Tensor joint_vec,
|
||||
const torch::Tensor fixed_transform,
|
||||
const torch::Tensor robot_spheres,
|
||||
const torch::Tensor link_map,
|
||||
const torch::Tensor joint_map,
|
||||
const torch::Tensor joint_map_type,
|
||||
const torch::Tensor store_link_map,
|
||||
const torch::Tensor link_sphere_map,
|
||||
const torch::Tensor link_chain_map,
|
||||
const int batch_size,
|
||||
const int n_spheres,
|
||||
const bool sparsity_opt = true,
|
||||
const bool use_global_cumul = false);
|
||||
|
||||
std::vector<torch::Tensor> kin_fused_backward_16t(
|
||||
torch::Tensor grad_out, const torch::Tensor grad_nlinks_pos,
|
||||
const torch::Tensor grad_nlinks_quat, const torch::Tensor grad_spheres,
|
||||
const torch::Tensor global_cumul_mat, const torch::Tensor joint_vec,
|
||||
const torch::Tensor fixed_transform, const torch::Tensor robot_spheres,
|
||||
const torch::Tensor link_map, const torch::Tensor joint_map,
|
||||
const torch::Tensor joint_map_type, const torch::Tensor store_link_map,
|
||||
const torch::Tensor link_sphere_map, const torch::Tensor link_chain_map,
|
||||
const int batch_size, const int n_spheres, const bool sparsity_opt = true,
|
||||
const bool use_global_cumul = false);
|
||||
// C++ interface
|
||||
|
||||
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), # x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
AT_ASSERTM(x.is_contiguous(), # x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::vector<torch::Tensor> kin_forward_wrapper(
|
||||
torch::Tensor link_pos, torch::Tensor link_quat,
|
||||
torch::Tensor batch_robot_spheres, torch::Tensor global_cumul_mat,
|
||||
const torch::Tensor joint_vec, const torch::Tensor fixed_transform,
|
||||
const torch::Tensor robot_spheres, const torch::Tensor link_map,
|
||||
const torch::Tensor joint_map, const torch::Tensor joint_map_type,
|
||||
const torch::Tensor store_link_map, const torch::Tensor link_sphere_map,
|
||||
const int batch_size, const int n_spheres,
|
||||
const bool use_global_cumul = false) {
|
||||
|
||||
std::vector<torch::Tensor>kin_forward_wrapper(
|
||||
torch::Tensor link_pos, torch::Tensor link_quat,
|
||||
torch::Tensor batch_robot_spheres, torch::Tensor global_cumul_mat,
|
||||
const torch::Tensor joint_vec, const torch::Tensor fixed_transform,
|
||||
const torch::Tensor robot_spheres, const torch::Tensor link_map,
|
||||
const torch::Tensor joint_map, const torch::Tensor joint_map_type,
|
||||
const torch::Tensor store_link_map, const torch::Tensor link_sphere_map,
|
||||
const int batch_size, const int n_spheres,
|
||||
const bool use_global_cumul = false)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard guard(joint_vec.device());
|
||||
|
||||
// TODO: add check input
|
||||
return kin_fused_forward(
|
||||
link_pos, link_quat, batch_robot_spheres, global_cumul_mat, joint_vec,
|
||||
fixed_transform, robot_spheres, link_map, joint_map, joint_map_type,
|
||||
store_link_map, link_sphere_map, batch_size, n_spheres, use_global_cumul);
|
||||
link_pos, link_quat, batch_robot_spheres, global_cumul_mat, joint_vec,
|
||||
fixed_transform, robot_spheres, link_map, joint_map, joint_map_type,
|
||||
store_link_map, link_sphere_map, batch_size, n_spheres, use_global_cumul);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> kin_backward_wrapper(
|
||||
torch::Tensor grad_out, const torch::Tensor grad_nlinks_pos,
|
||||
const torch::Tensor grad_nlinks_quat, const torch::Tensor grad_spheres,
|
||||
const torch::Tensor global_cumul_mat, const torch::Tensor joint_vec,
|
||||
const torch::Tensor fixed_transform, const torch::Tensor robot_spheres,
|
||||
const torch::Tensor link_map, const torch::Tensor joint_map,
|
||||
const torch::Tensor joint_map_type, const torch::Tensor store_link_map,
|
||||
const torch::Tensor link_sphere_map, const torch::Tensor link_chain_map,
|
||||
const int batch_size, const int n_spheres, const bool sparsity_opt = true,
|
||||
const bool use_global_cumul = false) {
|
||||
std::vector<torch::Tensor>kin_backward_wrapper(
|
||||
torch::Tensor grad_out, const torch::Tensor grad_nlinks_pos,
|
||||
const torch::Tensor grad_nlinks_quat, const torch::Tensor grad_spheres,
|
||||
const torch::Tensor global_cumul_mat, const torch::Tensor joint_vec,
|
||||
const torch::Tensor fixed_transform, const torch::Tensor robot_spheres,
|
||||
const torch::Tensor link_map, const torch::Tensor joint_map,
|
||||
const torch::Tensor joint_map_type, const torch::Tensor store_link_map,
|
||||
const torch::Tensor link_sphere_map, const torch::Tensor link_chain_map,
|
||||
const int batch_size, const int n_spheres, const bool sparsity_opt = true,
|
||||
const bool use_global_cumul = false)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard guard(joint_vec.device());
|
||||
|
||||
return kin_fused_backward_16t(
|
||||
grad_out, grad_nlinks_pos, grad_nlinks_quat, grad_spheres,
|
||||
global_cumul_mat, joint_vec, fixed_transform, robot_spheres, link_map,
|
||||
joint_map, joint_map_type, store_link_map, link_sphere_map,
|
||||
link_chain_map, batch_size, n_spheres, sparsity_opt, use_global_cumul);
|
||||
grad_out, grad_nlinks_pos, grad_nlinks_quat, grad_spheres,
|
||||
global_cumul_mat, joint_vec, fixed_transform, robot_spheres, link_map,
|
||||
joint_map, joint_map_type, store_link_map, link_sphere_map,
|
||||
link_chain_map, batch_size, n_spheres, sparsity_opt, use_global_cumul);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
matrix_to_quaternion_wrapper(torch::Tensor out_quat,
|
||||
matrix_to_quaternion_wrapper(torch::Tensor out_quat,
|
||||
const torch::Tensor in_rot // batch_size, 3
|
||||
) {
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard guard(in_rot.device());
|
||||
|
||||
CHECK_INPUT(in_rot);
|
||||
CHECK_INPUT(out_quat);
|
||||
return matrix_to_quaternion(out_quat, in_rot);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &kin_forward_wrapper, "Kinematics fused forward (CUDA)");
|
||||
m.def("backward", &kin_backward_wrapper, "Kinematics fused backward (CUDA)");
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("forward", &kin_forward_wrapper, "Kinematics fused forward (CUDA)");
|
||||
m.def("backward", &kin_backward_wrapper, "Kinematics fused backward (CUDA)");
|
||||
m.def("matrix_to_quaternion", &matrix_to_quaternion_wrapper,
|
||||
"Rotation Matrix to Quaternion (CUDA)");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user