Fixes for PR #4

This commit is contained in:
Simon Alibert
2024-03-01 14:59:05 +01:00
parent b862145e22
commit c1942d45d3
4 changed files with 38 additions and 39 deletions

View File

@@ -18,7 +18,7 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common import utils
from lerobot.common.datasets import utils
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
@@ -97,17 +97,15 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
replacement: bool = None,
streaming: bool = False,
root: Path = None,
download: bool | str = False,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
pin_memory: bool = False,
prefetch: int = None,
transform: "torchrl.envs.Transform" = None, # noqa-F821
transform: "torchrl.envs.Transform" = None, # noqa: F821
split_trajs: bool = False,
strict_length: bool = True,
):
self.download = download
if streaming:
raise NotImplementedError
self.streaming = streaming
@@ -129,8 +127,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
os.makedirs(root, exist_ok=True)
self.root = root
self.raw = self.root / "raw"
if self.download == "force" or (self.download and not self._is_downloaded()):
if not self._is_downloaded():
storage = self._download_and_preproc()
else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
@@ -192,9 +189,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
def _download_and_preproc(self):
# download
self.raw.mkdir(exist_ok=True)
utils.download_and_extract_zip(PUSHT_URL, self.raw)
zarr_path = (self.raw / PUSHT_ZARR).resolve()
raw_dir = self.root / "raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
utils.download_and_extract_zip(PUSHT_URL, raw_dir)
# load
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])