finish examples 2 and 3

This commit is contained in:
Alexander Soare
2024-03-26 16:13:40 +00:00
parent cb6d1e0871
commit 1ed0110900
10 changed files with 196 additions and 42 deletions

View File

@@ -59,6 +59,8 @@ class AbstractDataset(TensorDictReplayBuffer):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
# Don't actually load any data. This is a stand-in solution to get the transforms.
dummy: bool = False,
):
assert (
self.available_datasets is not None
@@ -77,7 +79,7 @@ class AbstractDataset(TensorDictReplayBuffer):
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
)
storage = self._download_or_load_dataset()
storage = self._download_or_load_dataset() if not dummy else None
super().__init__(
storage=storage,

View File

@@ -97,6 +97,7 @@ class AlohaDataset(AbstractDataset):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
dummy: bool = False,
):
super().__init__(
dataset_id,
@@ -110,6 +111,7 @@ class AlohaDataset(AbstractDataset):
collate_fn=collate_fn,
writer=writer,
transform=transform,
dummy=dummy,
)
@property

View File

@@ -21,7 +21,12 @@ def make_offline_buffer(
overwrite_batch_size=None,
overwrite_prefetch=None,
stats_path=None,
# Don't actually load any data. This is a stand-in solution to get the transforms.
dummy=False,
):
if dummy and normalize and stats_path is None:
raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
batch_size = None
@@ -93,6 +98,7 @@ def make_offline_buffer(
root=DATA_DIR,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
dummy=dummy,
)
if cfg.policy.name == "tdmpc":

View File

@@ -100,6 +100,7 @@ class PushtDataset(AbstractDataset):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
dummy: bool = False,
):
super().__init__(
dataset_id,
@@ -113,6 +114,7 @@ class PushtDataset(AbstractDataset):
collate_fn=collate_fn,
writer=writer,
transform=transform,
dummy=dummy,
)
def _download_and_preproc_obsolete(self):

View File

@@ -51,6 +51,7 @@ class SimxarmDataset(AbstractDataset):
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
dummy: bool = False,
):
super().__init__(
dataset_id,
@@ -64,6 +65,7 @@ class SimxarmDataset(AbstractDataset):
collate_fn=collate_fn,
writer=writer,
transform=transform,
dummy=dummy,
)
def _download_and_preproc_obsolete(self):