This commit is contained in:
Alexander Soare
2024-04-03 09:56:46 +01:00
parent a6ec4fbf58
commit caf4ffcf65

View File

@@ -152,7 +152,13 @@ class AbstractDataset(TensorDictReplayBuffer):
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
def _compute_stats(self, batch_size: int = 32):
"""Compute dataset statistics including minimum, maximum, mean, and standard deviation."""
"""Compute dataset statistics including minimum, maximum, mean, and standard deviation.
TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
full dataset (for handling very large datasets). The sampling would then have to be random
(preferably without replacement). Both stats computation loops would ideally sample the same
items.
"""
rb = TensorDictReplayBuffer(
storage=self._storage,
batch_size=32,