From b69a1327372aab1064f5ce7b0f3b3884268f823b Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 3 Nov 2024 19:30:56 +0100 Subject: [PATCH] Fix test_examples --- tests/test_examples.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 6b304863..24d26400 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -112,7 +112,8 @@ def test_examples_basic2_basic3_advanced1(): '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', ), - ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'), + ("train_episodes = episodes[:num_train_episodes]", 'train_episodes = [0]"'), + ("val_episodes = episodes[num_train_episodes:]", 'val_episodes = [1]"'), ("num_workers=4", "num_workers=0"), ('device = torch.device("cuda")', 'device = torch.device("cpu")'), ("batch_size=64", "batch_size=1"),