diff --git a/tests/test_examples.py b/tests/test_examples.py index 6b304863..24d26400 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -112,7 +112,8 @@ def test_examples_basic2_basic3_advanced1(): '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', ), - ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'), + ("train_episodes = episodes[:num_train_episodes]", 'train_episodes = [0]"'), + ("val_episodes = episodes[num_train_episodes:]", 'val_episodes = [1]"'), ("num_workers=4", "num_workers=0"), ('device = torch.device("cuda")', 'device = torch.device("cpu")'), ("batch_size=64", "batch_size=1"),