finish examples 2 and 3
This commit is contained in:
@@ -59,6 +59,8 @@ class AbstractDataset(TensorDictReplayBuffer):
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
# Don't actually load any data. This is a stand-in solution to get the transforms.
|
||||
dummy: bool = False,
|
||||
):
|
||||
assert (
|
||||
self.available_datasets is not None
|
||||
@@ -77,7 +79,7 @@ class AbstractDataset(TensorDictReplayBuffer):
|
||||
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
|
||||
)
|
||||
|
||||
storage = self._download_or_load_dataset()
|
||||
storage = self._download_or_load_dataset() if not dummy else None
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
|
||||
@@ -97,6 +97,7 @@ class AlohaDataset(AbstractDataset):
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
dummy: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
@@ -110,6 +111,7 @@ class AlohaDataset(AbstractDataset):
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -21,7 +21,12 @@ def make_offline_buffer(
|
||||
overwrite_batch_size=None,
|
||||
overwrite_prefetch=None,
|
||||
stats_path=None,
|
||||
# Don't actually load any data. This is a stand-in solution to get the transforms.
|
||||
dummy=False,
|
||||
):
|
||||
if dummy and normalize and stats_path is None:
|
||||
raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
|
||||
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
@@ -93,6 +98,7 @@ def make_offline_buffer(
|
||||
root=DATA_DIR,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
|
||||
@@ -100,6 +100,7 @@ class PushtDataset(AbstractDataset):
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
dummy: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
@@ -113,6 +114,7 @@ class PushtDataset(AbstractDataset):
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
|
||||
@@ -51,6 +51,7 @@ class SimxarmDataset(AbstractDataset):
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
dummy: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
@@ -64,6 +65,7 @@ class SimxarmDataset(AbstractDataset):
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
|
||||
Reference in New Issue
Block a user