constrained planning, robot segmentation

This commit is contained in:
Balakumar Sundaralingam
2024-02-22 21:45:47 -08:00
parent 88eac64edc
commit bafdf80c05
102 changed files with 12440 additions and 8112 deletions

View File

@@ -47,6 +47,9 @@ def test_repeat_seeds():
out_r_v = torch.zeros((b, 4), device=tensor_args.device, dtype=tensor_args.dtype)
out_idx = torch.zeros((b, 1), device=tensor_args.device, dtype=torch.int32)
vec_weight = torch.ones((6), device=tensor_args.device, dtype=tensor_args.dtype)
offset_waypoint = torch.zeros((6), device=tensor_args.device, dtype=tensor_args.dtype)
offset_tstep_fraction = torch.ones((1), device=tensor_args.device, dtype=tensor_args.dtype)
weight = tensor_args.to_device([1, 1, 1, 1])
vec_convergence = tensor_args.to_device([0, 0])
run_weight = tensor_args.to_device([1])
@@ -66,6 +69,8 @@ def test_repeat_seeds():
vec_convergence,
run_weight,
vec_weight.clone() * 0.0,
offset_waypoint,
offset_tstep_fraction,
g.batch_pose_idx,
start_pose.position.shape[0],
1,
@@ -74,6 +79,7 @@ def test_repeat_seeds():
False,
False,
True,
True,
)
assert torch.sum(r[0]).item() == 0.0
@@ -112,6 +118,9 @@ def test_horizon_repeat_seeds():
out_r_v = torch.zeros((b, h, 4), device=tensor_args.device, dtype=tensor_args.dtype)
out_idx = torch.zeros((b, h, 1), device=tensor_args.device, dtype=torch.int32)
vec_weight = torch.ones((6), device=tensor_args.device, dtype=tensor_args.dtype)
offset_waypoint = torch.zeros((6), device=tensor_args.device, dtype=tensor_args.dtype)
offset_tstep_fraction = torch.ones((1), device=tensor_args.device, dtype=tensor_args.dtype)
weight = tensor_args.to_device([1, 1, 1, 1])
vec_convergence = tensor_args.to_device([0, 0])
run_weight = torch.zeros((1, h), device=tensor_args.device)
@@ -132,6 +141,8 @@ def test_horizon_repeat_seeds():
vec_convergence,
run_weight,
vec_weight.clone() * 0.0,
offset_waypoint,
offset_tstep_fraction,
g.batch_pose_idx,
start_pose.position.shape[0],
h,
@@ -140,5 +151,6 @@ def test_horizon_repeat_seeds():
True,
False,
False,
True,
)
assert torch.sum(r[0]).item() == 0.0
assert torch.sum(r[0]).item() < 1e-6