Add aloha + improve readme

This commit is contained in:
Cadene
2024-03-15 00:30:11 +00:00
parent 19730b3412
commit a311d38796
8 changed files with 115 additions and 37 deletions

View File

@@ -1,4 +1,3 @@
import abc
import logging
from pathlib import Path
from typing import Callable
@@ -7,8 +6,8 @@ import einops
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
@@ -33,11 +32,8 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
):
self.dataset_id = dataset_id
self.shuffle = shuffle
self.root = _get_root_dir(self.dataset_id) if root is None else root
self.root = Path(self.root)
self.data_dir = self.root / self.dataset_id
storage = self._download_or_load_storage()
self.root = root
storage = self._download_or_load_dataset()
super().__init__(
storage=storage,
@@ -98,19 +94,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
torch.save(stats, stats_path)
return stats
@abc.abstractmethod
def _download_and_preproc(self) -> torch.StorageBase:
raise NotImplementedError()
def _download_or_load_storage(self):
if not self._is_downloaded():
storage = self._download_and_preproc()
def _download_or_load_dataset(self) -> torch.StorageBase:
if self.root is None:
data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
else:
storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
return storage
def _is_downloaded(self) -> bool:
return self.data_dir.is_dir()
data_dir = Path(self.root) / self.dataset_id
return TensorStorage(TensorDict.load_memmap(data_dir))
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(

View File

@@ -124,8 +124,8 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
def _download_and_preproc_obsolete(self, data_dir="data"):
raw_dir = Path(data_dir) / f"{self.dataset_id}_raw"
if not raw_dir.is_dir():
download(raw_dir, self.dataset_id)
@@ -174,7 +174,9 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
if ep_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir)
td_data = (
ep_td[0].expand(total_num_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
)
td_data[idxtd : idxtd + len(ep_td)] = ep_td
idxtd = idxtd + len(ep_td)

View File

@@ -1,13 +1,13 @@
import logging
import os
from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from lerobot.common.envs.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
# used for unit tests
DATA_DIR = os.environ.get("DATA_DIR", None)
def make_offline_buffer(
@@ -77,9 +77,9 @@ def make_offline_buffer(
offline_buffer = clsfunc(
dataset_id=dataset_id,
root=DATA_DIR,
sampler=sampler,
batch_size=batch_size,
root=DATA_DIR,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)

View File

@@ -8,7 +8,6 @@ import pymunk
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage
@@ -112,12 +111,8 @@ class PushtExperienceReplay(AbstractExperienceReplay):
transform=transform,
)
def _download_and_preproc(self):
snapshot_download(repo_id="cadene/pusht", local_dir=self.data_dir)
return TensorStorage(TensorDict.load_memmap(self.data_dir))
def _download_and_preproc_obsolete(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
raw_dir = Path(self.root) / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
@@ -213,7 +208,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
td_data = ep_td[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
td_data[idxtd : idxtd + len(ep_td)] = ep_td

View File

@@ -64,11 +64,11 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
transform=transform,
)
def _download_and_preproc(self):
def _download_and_preproc_obsolete(self):
# TODO(rcadene): finish download
download()
dataset_path = self.data_dir / "buffer.pkl"
dataset_path = Path(self.root) / "data" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
@@ -110,7 +110,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
td_data = episode[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
td_data[idx0:idx1] = episode