diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index d3015f3..a81de49 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -70,7 +70,7 @@ class AbstractDataset(TensorDictReplayBuffer): self.dataset_id = dataset_id self.version = version self.shuffle = shuffle - self.root = root + self.root = root if root is None else Path(root) if self.root is not None and self.version is not None: logging.warning( diff --git a/tests/test_examples.py b/tests/test_examples.py index b1d3a48..6c21eb4 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,7 +1,6 @@ import pytest from pathlib import Path -@pytest.mark.skip(reason="For some reason 1_visualize_dataset.py downloads the dataset") @pytest.mark.parametrize( "path", [ @@ -16,4 +15,5 @@ def test_example(path): file_contents = file.read() exec(file_contents) - assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists() + if path == "examples/1_visualize_dataset.py": + assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()