self.root is Path or None

The following packages are already present in the pyproject.toml and will be skipped:

  - huggingface-hub

If you want to update it to the latest compatible version, you can use `poetry update package`.
If you prefer to upgrade it to the latest available version, you can use `poetry add package@latest`.

Nothing to add.
.github/poetry/cpu/poetry.lock (generated, vendored): 4 changes
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -3123,4 +3123,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "66c60543d2f59ac3d0e1fcda298ea14c0c60a8c6bcea73902f4f6aa3dd47661b"
+content-hash = "4aa6a1e3f29560dd4a1c24d493ee1154089da4aa8d2190ad1f786c125ab2b735"
.github/poetry/cpu/pyproject.toml (vendored): 1 change
@@ -51,6 +51,7 @@ torchvision = {version = "^0.17.1", source = "torch-cpu"}
 h5py = "^3.10.0"
 dm = "^1.3"
 dm-control = "^1.0.16"
+huggingface-hub = "^0.21.4"
 
 
 [tool.poetry.group.dev.dependencies]
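For context, the `huggingface-hub` package pinned here provides the `snapshot_download` call used in the Python hunks below whenever no local root is given. A minimal sketch of that call, assuming the `cadene/pusht` dataset repo named in the diff (the returned cache path varies per machine):

    from huggingface_hub import snapshot_download

    # Download (or reuse a cached copy of) every file in the dataset repo;
    # returns the local directory where the snapshot was materialized.
    data_dir = snapshot_download(repo_id="cadene/pusht", repo_type="dataset")
    print(data_dir)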
@@ -32,7 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     ):
         self.dataset_id = dataset_id
         self.shuffle = shuffle
-        self.root = root
+        self.root = root if root is None else Path(root)
         storage = self._download_or_load_dataset()
 
         super().__init__(
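The net effect of this hunk: callers may pass `root` as a `str`, a `Path`, or `None`, and everything downstream sees only `Path` or `None`. A tiny illustration of the normalization pattern outside this class (class and argument names invented):

    from pathlib import Path

    class Example:
        def __init__(self, root=None):
            # Normalize once: str or Path become Path, None stays None.
            self.root = root if root is None else Path(root)

    assert Example("data").root == Path("data")
    assert Example().root is None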
@@ -98,7 +98,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         if self.root is None:
             self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
         else:
-            self.data_dir = Path(self.root) / self.dataset_id
+            self.data_dir = self.root / self.dataset_id
         return TensorStorage(TensorDict.load_memmap(self.data_dir))
 
     def _compute_stats(self, num_batch=100, batch_size=32):
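With `self.root` already guaranteed to be a `Path`, the redundant `Path(...)` wrap can go, and both branches resolve `data_dir` to a directory holding a memory-mapped TensorDict. A rough sketch of the load side, assuming a hypothetical `data/pusht` directory written by a previous preprocessing run:

    from tensordict import TensorDict
    from torchrl.data import TensorStorage

    # Map the on-disk tensors back into the process without copying,
    # then expose them as replay-buffer storage.
    td = TensorDict.load_memmap("data/pusht")
    storage = TensorStorage(td)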
@@ -124,8 +124,9 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def image_keys(self) -> list:
         return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
 
-    def _download_and_preproc_obsolete(self, data_dir="data"):
-        raw_dir = Path(data_dir) / f"{self.dataset_id}_raw"
+    def _download_and_preproc_obsolete(self):
+        assert self.root is not None
+        raw_dir = self.root / f"{self.dataset_id}_raw"
         if not raw_dir.is_dir():
             download(raw_dir, self.dataset_id)
 
@@ -174,9 +175,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
 
             if ep_id == 0:
                 # hack to initialize tensordict data structure to store episodes
-                td_data = (
-                    ep_td[0].expand(total_num_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
-                )
+                td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
 
             td_data[idxtd : idxtd + len(ep_td)] = ep_td
             idxtd = idxtd + len(ep_td)
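The "hack" these hunks touch preallocates one large memory-mapped TensorDict shaped like a single frame expanded to the total frame count, then fills it episode by episode via slice assignment. A toy version of the pattern, with invented keys, shapes, and output path:

    import torch
    from tensordict import TensorDict

    # One frame's worth of data (keys and shapes are illustrative).
    frame = TensorDict({"observation": torch.zeros(3), "action": torch.zeros(2)}, batch_size=[])
    total_frames = 10

    # Allocate on-disk storage structured like one frame repeated total_frames times.
    td_data = frame.expand(total_frames).memmap_like("data/toy_dataset")

    # Write episodes into contiguous slices.
    idx = 0
    for ep_len in (4, 6):
        ep_td = frame.expand(ep_len).clone()
        td_data[idx : idx + ep_len] = ep_td
        idx += ep_len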
@@ -112,7 +112,8 @@ class PushtExperienceReplay(AbstractExperienceReplay):
         )
 
     def _download_and_preproc_obsolete(self):
-        raw_dir = Path(self.root) / f"{self.dataset_id}_raw"
+        assert self.root is not None
+        raw_dir = self.root / f"{self.dataset_id}_raw"
         zarr_path = (raw_dir / PUSHT_ZARR).resolve()
         if not zarr_path.is_dir():
             raw_dir.mkdir(parents=True, exist_ok=True)
@@ -208,7 +209,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
 
         if episode_id == 0:
             # hack to initialize tensordict data structure to store episodes
-            td_data = ep_td[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
+            td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
 
         td_data[idxtd : idxtd + len(ep_td)] = ep_td
 
@@ -65,10 +65,11 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
         )
 
     def _download_and_preproc_obsolete(self):
+        assert self.root is not None
         # TODO(rcadene): finish download
         download()
 
-        dataset_path = Path(self.root) / "data" / "buffer.pkl"
+        dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl"
         print(f"Using offline dataset '{dataset_path}'")
         with open(dataset_path, "rb") as f:
             dataset_dict = pickle.load(f)
@@ -110,7 +111,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
 
         if episode_id == 0:
             # hack to initialize tensordict data structure to store episodes
-            td_data = episode[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
+            td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
 
         td_data[idx0:idx1] = episode
 
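Taken together, the commit leaves two call patterns, sketched hypothetically below using the `PushtExperienceReplay` class and `pusht` dataset id from the hunks above (other constructor arguments elided; exact signatures may differ):

    # No root: the dataset is fetched from / cached by the Hugging Face Hub.
    buffer = PushtExperienceReplay("pusht")

    # Local root: a str is normalized to Path("data") by the constructor.
    buffer = PushtExperienceReplay("pusht", root="data")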