constrained planning, robot segmentation
This commit is contained in:
@@ -47,6 +47,9 @@ def test_repeat_seeds():
|
||||
out_r_v = torch.zeros((b, 4), device=tensor_args.device, dtype=tensor_args.dtype)
|
||||
out_idx = torch.zeros((b, 1), device=tensor_args.device, dtype=torch.int32)
|
||||
vec_weight = torch.ones((6), device=tensor_args.device, dtype=tensor_args.dtype)
|
||||
offset_waypoint = torch.zeros((6), device=tensor_args.device, dtype=tensor_args.dtype)
|
||||
offset_tstep_fraction = torch.ones((1), device=tensor_args.device, dtype=tensor_args.dtype)
|
||||
|
||||
weight = tensor_args.to_device([1, 1, 1, 1])
|
||||
vec_convergence = tensor_args.to_device([0, 0])
|
||||
run_weight = tensor_args.to_device([1])
|
||||
@@ -66,6 +69,8 @@ def test_repeat_seeds():
|
||||
vec_convergence,
|
||||
run_weight,
|
||||
vec_weight.clone() * 0.0,
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
g.batch_pose_idx,
|
||||
start_pose.position.shape[0],
|
||||
1,
|
||||
@@ -74,6 +79,7 @@ def test_repeat_seeds():
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
)
|
||||
|
||||
assert torch.sum(r[0]).item() == 0.0
|
||||
@@ -112,6 +118,9 @@ def test_horizon_repeat_seeds():
|
||||
out_r_v = torch.zeros((b, h, 4), device=tensor_args.device, dtype=tensor_args.dtype)
|
||||
out_idx = torch.zeros((b, h, 1), device=tensor_args.device, dtype=torch.int32)
|
||||
vec_weight = torch.ones((6), device=tensor_args.device, dtype=tensor_args.dtype)
|
||||
offset_waypoint = torch.zeros((6), device=tensor_args.device, dtype=tensor_args.dtype)
|
||||
offset_tstep_fraction = torch.ones((1), device=tensor_args.device, dtype=tensor_args.dtype)
|
||||
|
||||
weight = tensor_args.to_device([1, 1, 1, 1])
|
||||
vec_convergence = tensor_args.to_device([0, 0])
|
||||
run_weight = torch.zeros((1, h), device=tensor_args.device)
|
||||
@@ -132,6 +141,8 @@ def test_horizon_repeat_seeds():
|
||||
vec_convergence,
|
||||
run_weight,
|
||||
vec_weight.clone() * 0.0,
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
g.batch_pose_idx,
|
||||
start_pose.position.shape[0],
|
||||
h,
|
||||
@@ -140,5 +151,6 @@ def test_horizon_repeat_seeds():
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
)
|
||||
assert torch.sum(r[0]).item() == 0.0
|
||||
assert torch.sum(r[0]).item() < 1e-6
|
||||
|
||||
Reference in New Issue
Block a user