Refactor SAC configuration and policy for improved action sampling and stability

- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
This commit is contained in:
KeWang1017
2024-12-29 12:30:39 +00:00
committed by Michel Aractingi
parent 22fbc9ea4a
commit 5b4adc00bb
2 changed files with 43 additions and 217 deletions

View File

@@ -53,30 +53,13 @@ class SACConfig:
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
}
actor_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
}
policy_kwargs = {
"tanh_squash_distribution": True,
"std_parameterization": "softplus",
"std_min": 0.005,
"std_max": 5.0,
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
state_encoder_hidden_dim: int = 256
latent_dim: int = 256
network_hidden_dims: int = 256
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)