Merge remote-tracking branch 'upstream/main' into fix_environment_seeding

This commit is contained in:
Alexander Soare
2024-03-22 15:11:15 +00:00
5 changed files with 137 additions and 27 deletions

View File

@@ -14,7 +14,12 @@ DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_offline_buffer(
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
cfg,
overwrite_sampler=None,
normalize=True,
overwrite_batch_size=None,
overwrite_prefetch=None,
stats_path=None,
):
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
@@ -98,10 +103,12 @@ def make_offline_buffer(
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
# we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
# we only normalize the state and action, since the images are usually normalized inside the model for
# now (except for tdmpc: see the following)
in_keys = [("observation", "state"), ("action")]
if cfg.policy.name == "tdmpc":