fix more bugs in normalization
This commit is contained in:
@@ -87,9 +87,11 @@ def make_offline_buffer(
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
|
||||
# we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy == "tdmpc":
|
||||
if cfg.policy.name == "tdmpc":
|
||||
for key in offline_buffer.image_keys:
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
|
||||
in_keys.append(key)
|
||||
@@ -97,7 +99,7 @@ def make_offline_buffer(
|
||||
in_keys.append(("next", *key))
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user