update pose inverse for goalset
This commit is contained in:
@@ -941,7 +941,6 @@ class PoseInverse(torch.autograd.Function):
|
|||||||
adj_position: torch.Tensor,
|
adj_position: torch.Tensor,
|
||||||
adj_quaternion: torch.Tensor,
|
adj_quaternion: torch.Tensor,
|
||||||
):
|
):
|
||||||
b, _ = position.shape
|
|
||||||
|
|
||||||
if out_position is None:
|
if out_position is None:
|
||||||
out_position = torch.zeros_like(position)
|
out_position = torch.zeros_like(position)
|
||||||
@@ -951,7 +950,8 @@ class PoseInverse(torch.autograd.Function):
|
|||||||
adj_position = torch.zeros_like(position)
|
adj_position = torch.zeros_like(position)
|
||||||
if adj_quaternion is None:
|
if adj_quaternion is None:
|
||||||
adj_quaternion = torch.zeros_like(quaternion)
|
adj_quaternion = torch.zeros_like(quaternion)
|
||||||
|
b, _ = position.view(-1, 3).shape
|
||||||
|
ctx.b = b
|
||||||
init_warp()
|
init_warp()
|
||||||
ctx.save_for_backward(
|
ctx.save_for_backward(
|
||||||
position,
|
position,
|
||||||
@@ -961,7 +961,6 @@ class PoseInverse(torch.autograd.Function):
|
|||||||
adj_position,
|
adj_position,
|
||||||
adj_quaternion,
|
adj_quaternion,
|
||||||
)
|
)
|
||||||
ctx.b = b
|
|
||||||
|
|
||||||
wp.launch(
|
wp.launch(
|
||||||
kernel=compute_pose_inverse,
|
kernel=compute_pose_inverse,
|
||||||
@@ -976,9 +975,6 @@ class PoseInverse(torch.autograd.Function):
|
|||||||
],
|
],
|
||||||
stream=wp.stream_from_torch(position.device),
|
stream=wp.stream_from_torch(position.device),
|
||||||
)
|
)
|
||||||
# remove close to zero values:
|
|
||||||
# out_position[torch.abs(out_position)<1e-8] = 0.0
|
|
||||||
# out_quaternion[torch.abs(out_quaternion)<1e-8] = 0.0
|
|
||||||
|
|
||||||
return out_position, out_quaternion
|
return out_position, out_quaternion
|
||||||
|
|
||||||
|
|||||||
@@ -365,7 +365,9 @@ class Pose(Sequence):
|
|||||||
|
|
||||||
@profiler.record_function("pose/multiply")
|
@profiler.record_function("pose/multiply")
|
||||||
def multiply(self, other_pose: Pose):
|
def multiply(self, other_pose: Pose):
|
||||||
if self.shape == other_pose.shape or (self.shape[0] == 1 and other_pose.shape[0] > 1):
|
if self.shape == other_pose.shape or (
|
||||||
|
(self.shape[0] == 1 and other_pose.shape[0] > 1) and len(other_pose.shape) == 2
|
||||||
|
):
|
||||||
p3, q3 = pose_multiply(
|
p3, q3 = pose_multiply(
|
||||||
self.position, self.quaternion, other_pose.position, other_pose.quaternion
|
self.position, self.quaternion, other_pose.position, other_pose.quaternion
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,23 +30,29 @@ def motion_gen(request):
|
|||||||
tensor_args = TensorDeviceType()
|
tensor_args = TensorDeviceType()
|
||||||
world_file = "collision_table.yml"
|
world_file = "collision_table.yml"
|
||||||
robot_file = "franka.yml"
|
robot_file = "franka.yml"
|
||||||
|
|
||||||
motion_gen_config = MotionGenConfig.load_from_robot_config(
|
motion_gen_config = MotionGenConfig.load_from_robot_config(
|
||||||
robot_file,
|
robot_file,
|
||||||
world_file,
|
world_file,
|
||||||
tensor_args,
|
tensor_args,
|
||||||
use_cuda_graph=True,
|
use_cuda_graph=True,
|
||||||
project_pose_to_goal_frame=request.param,
|
project_pose_to_goal_frame=request.param[0],
|
||||||
)
|
)
|
||||||
motion_gen_instance = MotionGen(motion_gen_config)
|
motion_gen_instance = MotionGen(motion_gen_config)
|
||||||
motion_gen_instance.warmup(enable_graph=False, warmup_js_trajopt=False)
|
|
||||||
|
motion_gen_instance.warmup(
|
||||||
|
enable_graph=False, warmup_js_trajopt=False, n_goalset=request.param[1]
|
||||||
|
)
|
||||||
return motion_gen_instance
|
return motion_gen_instance
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"motion_gen",
|
"motion_gen",
|
||||||
[
|
[
|
||||||
(True),
|
([True, -1]),
|
||||||
(False),
|
([False, -1]),
|
||||||
|
([True, 10]),
|
||||||
|
([False, 10]),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -77,8 +83,10 @@ def test_approach_grasp_pose(motion_gen):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"motion_gen",
|
"motion_gen",
|
||||||
[
|
[
|
||||||
(True),
|
([True, -1]),
|
||||||
(False),
|
([False, -1]),
|
||||||
|
([True, 10]),
|
||||||
|
([False, 10]),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -112,8 +120,10 @@ def test_reach_only_position(motion_gen):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"motion_gen",
|
"motion_gen",
|
||||||
[
|
[
|
||||||
(True),
|
([True, -1]),
|
||||||
(False),
|
([False, -1]),
|
||||||
|
([True, 10]),
|
||||||
|
([False, 10]),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -147,8 +157,10 @@ def test_reach_only_orientation(motion_gen):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"motion_gen",
|
"motion_gen",
|
||||||
[
|
[
|
||||||
(True),
|
([True, -1]),
|
||||||
(False),
|
([False, -1]),
|
||||||
|
([True, 10]),
|
||||||
|
([False, 10]),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -186,8 +198,10 @@ def test_hold_orientation(motion_gen):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"motion_gen",
|
"motion_gen",
|
||||||
[
|
[
|
||||||
(True),
|
([True, -1]),
|
||||||
(False),
|
([False, -1]),
|
||||||
|
([True, 10]),
|
||||||
|
([False, 10]),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -224,8 +238,10 @@ def test_hold_position(motion_gen):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"motion_gen",
|
"motion_gen",
|
||||||
[
|
[
|
||||||
(False),
|
([True, -1]),
|
||||||
(True),
|
([False, -1]),
|
||||||
|
([True, 10]),
|
||||||
|
([False, 10]),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
@@ -276,8 +292,10 @@ def test_hold_partial_pose(motion_gen):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"motion_gen",
|
"motion_gen",
|
||||||
[
|
[
|
||||||
(False),
|
([True, -1]),
|
||||||
(True),
|
([False, -1]),
|
||||||
|
([True, 10]),
|
||||||
|
([False, 10]),
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user