Add aloha + improve readme

This commit is contained in:
Cadene
2024-03-15 00:30:11 +00:00
parent 19730b3412
commit a311d38796
8 changed files with 115 additions and 37 deletions

View File

@@ -1,4 +1,3 @@
import abc
import logging
from pathlib import Path
from typing import Callable
@@ -7,8 +6,8 @@ import einops
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
@@ -33,11 +32,8 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
):
self.dataset_id = dataset_id
self.shuffle = shuffle
self.root = _get_root_dir(self.dataset_id) if root is None else root
self.root = Path(self.root)
self.data_dir = self.root / self.dataset_id
storage = self._download_or_load_storage()
self.root = root
storage = self._download_or_load_dataset()
super().__init__(
storage=storage,
@@ -98,19 +94,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
torch.save(stats, stats_path)
return stats
@abc.abstractmethod
def _download_and_preproc(self) -> torch.StorageBase:
raise NotImplementedError()
def _download_or_load_storage(self):
if not self._is_downloaded():
storage = self._download_and_preproc()
def _download_or_load_dataset(self) -> torch.StorageBase:
if self.root is None:
data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
else:
storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
return storage
def _is_downloaded(self) -> bool:
return self.data_dir.is_dir()
data_dir = Path(self.root) / self.dataset_id
return TensorStorage(TensorDict.load_memmap(data_dir))
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(