diff --git a/tests/test_examples.py b/tests/test_examples.py index 24d26400..b8505790 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -112,8 +112,8 @@ def test_examples_basic2_basic3_advanced1(): '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', ), - ("train_episodes = episodes[:num_train_episodes]", 'train_episodes = [0]"'), - ("val_episodes = episodes[num_train_episodes:]", 'val_episodes = [1]"'), + ("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"), + ("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"), ("num_workers=4", "num_workers=0"), ('device = torch.device("cuda")', 'device = torch.device("cpu")'), ("batch_size=64", "batch_size=1"),