Make sure targets are normalized too (#106)

This commit is contained in:
Alexander Soare
2024-04-26 11:18:39 +01:00
committed by GitHub
parent b980c5dd9e
commit 45f351c618
8 changed files with 116 additions and 92 deletions

View File

@@ -83,17 +83,13 @@ class DiffusionConfig:
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
# Architecture / modeling.
# Vision backbone.