Add pusht on hf dataset (WIP)
This commit is contained in:
@@ -8,6 +8,7 @@ import pymunk
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
@@ -112,6 +113,10 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
snapshot_download(repo_id="cadene/pusht", local_dir=self.data_dir)
|
||||
return TensorStorage(TensorDict.load_memmap(self.data_dir))
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
|
||||
Reference in New Issue
Block a user