diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock index d224b6682..c07e34395 100644 --- a/.github/poetry/cpu/poetry.lock +++ b/.github/poetry/cpu/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -3123,4 +3123,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "66c60543d2f59ac3d0e1fcda298ea14c0c60a8c6bcea73902f4f6aa3dd47661b" +content-hash = "4aa6a1e3f29560dd4a1c24d493ee1154089da4aa8d2190ad1f786c125ab2b735" diff --git a/.github/poetry/cpu/pyproject.toml b/.github/poetry/cpu/pyproject.toml index 4880f61e2..fd7eb226a 100644 --- a/.github/poetry/cpu/pyproject.toml +++ b/.github/poetry/cpu/pyproject.toml @@ -51,6 +51,7 @@ torchvision = {version = "^0.17.1", source = "torch-cpu"} h5py = "^3.10.0" dm = "^1.3" dm-control = "^1.0.16" +huggingface-hub = "^0.21.4" [tool.poetry.group.dev.dependencies] diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 61a0d25bc..3e0e2c320 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -32,7 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): ): self.dataset_id = dataset_id self.shuffle = shuffle - self.root = root + self.root = root if root is None else Path(root) storage = self._download_or_load_dataset() super().__init__( @@ -98,7 +98,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): if self.root is None: self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset") else: - self.data_dir = Path(self.root) / self.dataset_id + self.data_dir = self.root / self.dataset_id return TensorStorage(TensorDict.load_memmap(self.data_dir)) def _compute_stats(self, num_batch=100, batch_size=32): diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 68a3aa82d..2ea4b831a 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -124,8 +124,9 @@ class AlohaExperienceReplay(AbstractExperienceReplay): def image_keys(self) -> list: return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] - def _download_and_preproc_obsolete(self, data_dir="data"): - raw_dir = Path(data_dir) / f"{self.dataset_id}_raw" + def _download_and_preproc_obsolete(self): + assert self.root is not None + raw_dir = self.root / f"{self.dataset_id}_raw" if not raw_dir.is_dir(): download(raw_dir, self.dataset_id) @@ -174,9 +175,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): if ep_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ( - ep_td[0].expand(total_num_frames).memmap_like(Path(self.root) / f"{self.dataset_id}") - ) + td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}") td_data[idxtd : idxtd + len(ep_td)] = ep_td idxtd = idxtd + len(ep_td) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index ed2ec4eed..bac742d99 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -112,7 +112,8 @@ class PushtExperienceReplay(AbstractExperienceReplay): ) def _download_and_preproc_obsolete(self): - raw_dir = Path(self.root) / f"{self.dataset_id}_raw" + assert self.root is not None + raw_dir = self.root / f"{self.dataset_id}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() if not zarr_path.is_dir(): raw_dir.mkdir(parents=True, exist_ok=True) @@ -208,7 +209,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}") + td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") td_data[idxtd : idxtd + len(ep_td)] = ep_td diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 1d620c358..b4dd824f9 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -65,10 +65,11 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): ) def _download_and_preproc_obsolete(self): + assert self.root is not None # TODO(rcadene): finish download download() - dataset_path = Path(self.root) / "data" / "buffer.pkl" + dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") with open(dataset_path, "rb") as f: dataset_dict = pickle.load(f) @@ -110,7 +111,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = episode[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}") + td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") td_data[idx0:idx1] = episode