Add pusht on hf dataset (WIP)
This commit is contained in:
@@ -8,6 +8,7 @@ import pymunk
|
|||||||
import torch
|
import torch
|
||||||
import torchrl
|
import torchrl
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||||
@@ -112,6 +113,10 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _download_and_preproc(self):
|
def _download_and_preproc(self):
|
||||||
|
snapshot_download(repo_id="cadene/pusht", local_dir=self.data_dir)
|
||||||
|
return TensorStorage(TensorDict.load_memmap(self.data_dir))
|
||||||
|
|
||||||
|
def _download_and_preproc_obsolete(self):
|
||||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||||
if not zarr_path.is_dir():
|
if not zarr_path.is_dir():
|
||||||
|
|||||||
2
poetry.lock
generated
2
poetry.lock
generated
@@ -3254,4 +3254,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8"
|
content-hash = "0794a87fd309dffa0ad2982b6902bed7f35ae9e2a82433420516798da04c7197"
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ diffusers = "^0.26.3"
|
|||||||
torchvision = "^0.17.1"
|
torchvision = "^0.17.1"
|
||||||
h5py = "^3.10.0"
|
h5py = "^3.10.0"
|
||||||
dm-control = "1.0.14"
|
dm-control = "1.0.14"
|
||||||
|
huggingface-hub = "^0.21.4"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
|||||||
Reference in New Issue
Block a user