fix more bugs in normalization

This commit is contained in:
Cadene
2024-03-11 11:03:13 +00:00
parent a7ef4a6a33
commit 816b2e9d63
5 changed files with 14 additions and 8 deletions

View File

@@ -87,9 +87,11 @@ def make_offline_buffer(
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
# we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
in_keys = [("observation", "state"), ("action")]
if cfg.policy == "tdmpc":
if cfg.policy.name == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(key)
@@ -97,7 +99,7 @@ def make_offline_buffer(
in_keys.append(("next", *key))
in_keys.append(("next", "observation", "state"))
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)