Add planning to grasp API

This commit is contained in:
Balakumar Sundaralingam
2024-11-22 14:15:18 -08:00
parent 18e9ebd35f
commit 36ea382dab
38 changed files with 939 additions and 535 deletions

View File

@@ -53,6 +53,7 @@ def test_repeat_seeds():
weight = tensor_args.to_device([1, 1, 1, 1])
vec_convergence = tensor_args.to_device([0, 0])
run_weight = tensor_args.to_device([1])
project_distance = torch.tensor([True], device=tensor_args.device, dtype=torch.uint8)
r = get_pose_distance(
out_d,
out_d.clone(),
@@ -72,6 +73,7 @@ def test_repeat_seeds():
offset_waypoint,
offset_tstep_fraction,
g.batch_pose_idx,
project_distance,
start_pose.position.shape[0],
1,
1,
@@ -79,7 +81,6 @@ def test_repeat_seeds():
False,
False,
True,
True,
)
assert torch.sum(r[0]).item() <= 1e-5
@@ -105,6 +106,8 @@ def test_horizon_repeat_seeds():
batch_pose_idx = torch.arange(0, b, 1, device=tensor_args.device, dtype=torch.int32).unsqueeze(
-1
)
project_distance = torch.tensor([True], device=tensor_args.device, dtype=torch.uint8)
goal = Goal(goal_pose=goal_pose, batch_pose_idx=batch_pose_idx, current_state=current_state)
g = goal # .repeat_seeds(4)
@@ -144,6 +147,7 @@ def test_horizon_repeat_seeds():
offset_waypoint,
offset_tstep_fraction,
g.batch_pose_idx,
project_distance,
start_pose.position.shape[0],
h,
1,
@@ -151,6 +155,5 @@ def test_horizon_repeat_seeds():
True,
False,
False,
True,
)
assert torch.sum(r[0]).item() < 1e-5