Merge remote-tracking branch 'upstream/main' into user/alexander-soare/multistep_policy_and_serial_env

This commit is contained in:
Alexander Soare
2024-03-20 09:01:45 +00:00
76 changed files with 92 additions and 14 deletions

View File

@@ -19,6 +19,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
dataset_id: str,
version: str | None = None,
batch_size: int = None,
*,
shuffle: bool = True,
@@ -31,8 +32,15 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
transform: "torchrl.envs.Transform" = None,
):
self.dataset_id = dataset_id
self.version = version
self.shuffle = shuffle
self.root = root
if self.root is not None and self.version is not None:
logging.warning(
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
)
storage = self._download_or_load_dataset()
super().__init__(
@@ -96,10 +104,14 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def _download_or_load_dataset(self) -> torch.StorageBase:
    """Locate the dataset directory and load its replay buffer from disk.

    Sets ``self.data_dir`` as a side effect:

    * when ``self.root`` is ``None``, the directory is whatever
      ``snapshot_download`` returns for the dataset repo
      ``cadene/<self.dataset_id>``, requested at revision ``self.version``;
    * otherwise it is ``self.root / self.dataset_id`` and ``self.version``
      plays no part in the lookup.

    Returns:
        A ``TensorStorage`` wrapping the memory-mapped tensordict read from
        the ``replay_buffer`` sub-directory of ``self.data_dir``.
    """
    if self.root is None:
        # Download exactly once, pinned to the requested revision. The stale
        # un-pinned call that preceded this one fetched the same repo a second
        # time and its result was immediately overwritten.
        self.data_dir = Path(
            snapshot_download(
                repo_id=f"cadene/{self.dataset_id}", repo_type="dataset", revision=self.version
            )
        )
    else:
        # NOTE(review): `/` implies root is a Path-like object — confirm callers
        # never pass a plain str here.
        self.data_dir = self.root / self.dataset_id
    # Single return: the earlier `load_memmap(self.data_dir)` return made this
    # statement unreachable, so the "replay_buffer" layout was never used.
    return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(

View File

@@ -84,6 +84,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.0",
batch_size: int = None,
*,
shuffle: bool = True,
@@ -99,6 +100,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
super().__init__(
dataset_id,
version,
batch_size,
shuffle=shuffle,
root=root,

View File

@@ -87,6 +87,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.0",
batch_size: int = None,
*,
shuffle: bool = True,
@@ -100,6 +101,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
):
super().__init__(
dataset_id,
version,
batch_size,
shuffle=shuffle,
root=root,

View File

@@ -40,6 +40,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
def __init__(
self,
dataset_id: str,
version: str | None = None,
batch_size: int = None,
*,
shuffle: bool = True,
@@ -53,6 +54,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
):
super().__init__(
dataset_id,
version,
batch_size,
shuffle=shuffle,
root=root,