finish examples 2 and 3

This commit is contained in:
Alexander Soare
2024-03-26 16:13:40 +00:00
parent cb6d1e0871
commit 1ed0110900
10 changed files with 196 additions and 42 deletions

View File

@@ -21,7 +21,12 @@ def make_offline_buffer(
overwrite_batch_size=None,
overwrite_prefetch=None,
stats_path=None,
# Don't actually load any data. This is a stand-in solution to get the transforms.
dummy=False,
):
if dummy and normalize and stats_path is None:
raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
batch_size = None
@@ -93,6 +98,7 @@ def make_offline_buffer(
root=DATA_DIR,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
dummy=dummy,
)
if cfg.policy.name == "tdmpc":