diff --git a/tests/goal_test.py b/tests/goal_test.py index 740b9c9..b95ac1b 100644 --- a/tests/goal_test.py +++ b/tests/goal_test.py @@ -82,7 +82,7 @@ def test_repeat_seeds(): True, ) - assert torch.sum(r[0]).item() == 0.0 + assert torch.sum(r[0]).item() <= 1e-5 def test_horizon_repeat_seeds(): @@ -153,4 +153,4 @@ def test_horizon_repeat_seeds(): False, True, ) - assert torch.sum(r[0]).item() < 1e-6 + assert torch.sum(r[0]).item() < 1e-5