Add replay_buffer directory in pusht datasets + aloha (WIP)

This commit is contained in:
Cadene
2024-03-19 15:49:45 +00:00
parent 099a465367
commit 6a1a29386a
20 changed files with 53 additions and 8 deletions

View File

@@ -19,6 +19,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
dataset_id: str,
version: str | None = None,
batch_size: int = None,
*,
shuffle: bool = True,
@@ -31,6 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
transform: "torchrl.envs.Transform" = None,
):
self.dataset_id = dataset_id
self.version = version
self.shuffle = shuffle
self.root = root
storage = self._download_or_load_dataset()
@@ -96,10 +98,14 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def _download_or_load_dataset(self) -> torch.StorageBase:
    """Locate the dataset on disk and load its replay buffer.

    When ``self.root`` is ``None`` the dataset repository
    ``cadene/<dataset_id>`` is fetched with ``snapshot_download``, pinned to
    ``self.version`` as the revision. Otherwise the dataset is expected
    under ``<root>/<dataset_id>``.

    Side effect: sets ``self.data_dir`` to the resolved dataset directory.

    Returns:
        A ``TensorStorage`` wrapping the memory-mapped ``TensorDict`` loaded
        from the ``replay_buffer`` sub-directory of ``self.data_dir``.
    """
    if self.root is None:
        # Single download, pinned to the requested revision.
        # NOTE(review): revision=None presumably resolves to the repo's
        # default branch -- confirm against the huggingface_hub docs.
        self.data_dir = Path(
            snapshot_download(
                repo_id=f"cadene/{self.dataset_id}",
                repo_type="dataset",
                revision=self.version,
            )
        )
    else:
        # Path() makes a plain string root work as well as a Path instance.
        self.data_dir = Path(self.root) / self.dataset_id

    # The memmap tensors live in the "replay_buffer" sub-directory, not at
    # the dataset root.
    return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(