Add aloha + improve readme

This commit is contained in:
Cadene
2024-03-15 00:30:11 +00:00
parent 19730b3412
commit a311d38796
8 changed files with 115 additions and 37 deletions

View File

@@ -8,7 +8,6 @@ import pymunk
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage
@@ -112,12 +111,8 @@ class PushtExperienceReplay(AbstractExperienceReplay):
transform=transform,
)
def _download_and_preproc(self):
snapshot_download(repo_id="cadene/pusht", local_dir=self.data_dir)
return TensorStorage(TensorDict.load_memmap(self.data_dir))
def _download_and_preproc_obsolete(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
raw_dir = Path(self.root) / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
@@ -213,7 +208,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
td_data = ep_td[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
td_data[idxtd : idxtd + len(ep_td)] = ep_td