Merge remote-tracking branch 'origin/main' into user/alexander-soare/multistep_policy_and_serial_env

This commit is contained in:
Alexander Soare
2024-03-15 13:05:35 +00:00
10 changed files with 144 additions and 54 deletions

View File

@@ -1,4 +1,3 @@
import abc
import logging
from pathlib import Path
from typing import Callable
@@ -7,8 +6,8 @@ import einops
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
@@ -23,7 +22,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
@@ -33,11 +32,8 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
):
self.dataset_id = dataset_id
self.shuffle = shuffle
self.root = _get_root_dir(self.dataset_id) if root is None else root
self.root = Path(self.root)
self.data_dir = self.root / self.dataset_id
storage = self._download_or_load_storage()
self.root = root
storage = self._download_or_load_dataset()
super().__init__(
storage=storage,
@@ -98,19 +94,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
torch.save(stats, stats_path)
return stats
@abc.abstractmethod
def _download_and_preproc(self) -> torch.StorageBase:
raise NotImplementedError()
def _download_or_load_storage(self):
if not self._is_downloaded():
storage = self._download_and_preproc()
def _download_or_load_dataset(self) -> torch.StorageBase:
if self.root is None:
self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
else:
storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
return storage
def _is_downloaded(self) -> bool:
return self.data_dir.is_dir()
self.data_dir = self.root / self.dataset_id
return TensorStorage(TensorDict.load_memmap(self.data_dir))
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(

View File

@@ -87,7 +87,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
@@ -124,8 +124,9 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
if not raw_dir.is_dir():
download(raw_dir, self.dataset_id)
@@ -174,7 +175,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
if ep_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir)
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data[idxtd : idxtd + len(ep_td)] = ep_td
idxtd = idxtd + len(ep_td)

View File

@@ -7,7 +7,10 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from lerobot.common.envs.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
# DATA_DIR specifies the location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_offline_buffer(
@@ -77,9 +80,9 @@ def make_offline_buffer(
offline_buffer = clsfunc(
dataset_id=dataset_id,
root=DATA_DIR,
sampler=sampler,
batch_size=batch_size,
root=DATA_DIR,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)

View File

@@ -90,7 +90,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
@@ -111,8 +111,9 @@ class PushtExperienceReplay(AbstractExperienceReplay):
transform=transform,
)
def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
@@ -208,7 +209,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data[idxtd : idxtd + len(ep_td)] = ep_td

View File

@@ -43,7 +43,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
@@ -64,11 +64,12 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
transform=transform,
)
def _download_and_preproc(self):
def _download_and_preproc_obsolete(self):
assert self.root is not None
# TODO(rcadene): finish download
download()
dataset_path = self.data_dir / "buffer.pkl"
dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
@@ -110,7 +111,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data[idx0:idx1] = episode