1, add input normalization in configuration_sac.py 2, add masking on loss computation

This commit is contained in:
KeWang1017
2024-12-30 18:46:25 +00:00
committed by Ke-Wang1017
parent 35de91ef2b
commit 77a7f92139
3 changed files with 26 additions and 26 deletions

View File

@@ -28,12 +28,18 @@ class SACConfig:
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
"action": [2],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
"observation.environment_state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)