This commit is contained in:
Alexander Soare
2024-03-27 18:33:48 +00:00
parent 120f0aef5c
commit b7c9c33072
10 changed files with 20 additions and 33 deletions

View File

@@ -59,8 +59,6 @@ class AbstractDataset(TensorDictReplayBuffer):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
# Don't actually load any data. This is a stand-in solution to get the transforms.
dummy: bool = False,
):
assert (
self.available_datasets is not None
@@ -79,7 +77,7 @@ class AbstractDataset(TensorDictReplayBuffer):
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
)
storage = self._download_or_load_dataset() if not dummy else None
storage = self._download_or_load_dataset()
super().__init__(
storage=storage,

View File

@@ -97,7 +97,6 @@ class AlohaDataset(AbstractDataset):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
dummy: bool = False,
):
super().__init__(
dataset_id,
@@ -111,7 +110,6 @@ class AlohaDataset(AbstractDataset):
collate_fn=collate_fn,
writer=writer,
transform=transform,
dummy=dummy,
)
@property

View File

@@ -21,12 +21,7 @@ def make_offline_buffer(
overwrite_batch_size=None,
overwrite_prefetch=None,
stats_path=None,
# Don't actually load any data. This is a stand-in solution to get the transforms.
dummy=False,
):
if dummy and normalize and stats_path is None:
raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
batch_size = None
@@ -93,7 +88,6 @@ def make_offline_buffer(
root=DATA_DIR,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
dummy=dummy,
)
if cfg.policy.name == "tdmpc":

View File

@@ -100,7 +100,6 @@ class PushtDataset(AbstractDataset):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
dummy: bool = False,
):
super().__init__(
dataset_id,
@@ -114,7 +113,6 @@ class PushtDataset(AbstractDataset):
collate_fn=collate_fn,
writer=writer,
transform=transform,
dummy=dummy,
)
def _download_and_preproc_obsolete(self):

View File

@@ -51,7 +51,6 @@ class SimxarmDataset(AbstractDataset):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
dummy: bool = False,
):
super().__init__(
dataset_id,
@@ -65,7 +64,6 @@ class SimxarmDataset(AbstractDataset):
collate_fn=collate_fn,
writer=writer,
transform=transform,
dummy=dummy,
)
def _download_and_preproc_obsolete(self):