Add planning to grasp API
This commit is contained in:
@@ -53,6 +53,7 @@ def test_repeat_seeds():
|
||||
weight = tensor_args.to_device([1, 1, 1, 1])
|
||||
vec_convergence = tensor_args.to_device([0, 0])
|
||||
run_weight = tensor_args.to_device([1])
|
||||
project_distance = torch.tensor([True], device=tensor_args.device, dtype=torch.uint8)
|
||||
r = get_pose_distance(
|
||||
out_d,
|
||||
out_d.clone(),
|
||||
@@ -72,6 +73,7 @@ def test_repeat_seeds():
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
g.batch_pose_idx,
|
||||
project_distance,
|
||||
start_pose.position.shape[0],
|
||||
1,
|
||||
1,
|
||||
@@ -79,7 +81,6 @@ def test_repeat_seeds():
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
)
|
||||
|
||||
assert torch.sum(r[0]).item() <= 1e-5
|
||||
@@ -105,6 +106,8 @@ def test_horizon_repeat_seeds():
|
||||
batch_pose_idx = torch.arange(0, b, 1, device=tensor_args.device, dtype=torch.int32).unsqueeze(
|
||||
-1
|
||||
)
|
||||
project_distance = torch.tensor([True], device=tensor_args.device, dtype=torch.uint8)
|
||||
|
||||
goal = Goal(goal_pose=goal_pose, batch_pose_idx=batch_pose_idx, current_state=current_state)
|
||||
g = goal # .repeat_seeds(4)
|
||||
|
||||
@@ -144,6 +147,7 @@ def test_horizon_repeat_seeds():
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
g.batch_pose_idx,
|
||||
project_distance,
|
||||
start_pose.position.shape[0],
|
||||
h,
|
||||
1,
|
||||
@@ -151,6 +155,5 @@ def test_horizon_repeat_seeds():
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
)
|
||||
assert torch.sum(r[0]).item() < 1e-5
|
||||
|
||||
Reference in New Issue
Block a user