Refactor SAC configuration and policy for improved action sampling and stability
- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
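The action sampling and log probability calculation mentioned above can be sketched roughly as follows. This is a minimal standard-library illustration, not the SACPolicy code from this commit (which is not shown here). The helper names `bounded_log_std`, `softplus`, and `sample_squashed_action` are hypothetical, and bounding the log standard deviation by clamping is an assumption; a tanh-based rescaling into the same interval is another common choice.

```python
import math
import random

LOG_STD_MIN = -5.0  # default value taken from the diff
LOG_STD_MAX = 2.0   # default value taken from the diff


def bounded_log_std(raw, log_std_min=LOG_STD_MIN, log_std_max=LOG_STD_MAX):
    """Clamp a raw network output into [log_std_min, log_std_max] (assumed scheme)."""
    return min(max(raw, log_std_min), log_std_max)


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sample_squashed_action(mean, raw_log_std, rng):
    """Sample a tanh-squashed diagonal Gaussian action; return (action, log_prob)."""
    action, log_prob = [], 0.0
    for mu, raw in zip(mean, raw_log_std):
        log_std = bounded_log_std(raw)
        eps = rng.gauss(0.0, 1.0)
        u = mu + math.exp(log_std) * eps  # reparameterized pre-squash sample
        # Log-density of u under N(mu, std^2), written in terms of eps.
        log_prob += -0.5 * eps * eps - log_std - 0.5 * math.log(2.0 * math.pi)
        # Change of variables for a = tanh(u):
        # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)).
        log_prob -= 2.0 * (math.log(2.0) - u - softplus(-2.0 * u))
        action.append(math.tanh(u))
    return action, log_prob


rng = random.Random(0)
action, log_prob = sample_squashed_action([0.0] * 4, [-10.0, 0.0, 10.0, 1.0], rng)
```

The softplus form of the tanh correction avoids computing `log(1 - a**2)` directly, which becomes unstable when an action saturates near -1 or 1.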
Committed by Michel Aractingi
Parent: 22fbc9ea4a
Commit: 5b4adc00bb
@@ -53,30 +53,13 @@ class SACConfig:
critic_network_kwargs = {
    "hidden_dims": [256, 256],
    "activate_final": True,
}
}
actor_network_kwargs = {
    "hidden_dims": [256, 256],
    "activate_final": True,
}
}
policy_kwargs = {
    "tanh_squash_distribution": True,
    "std_parameterization": "softplus",
    "std_min": 0.005,
    "std_max": 5.0,
    "use_tanh_squash": True,
    "log_std_min": -5,
    "log_std_max": 2,
}
)
output_shapes: dict[str, list[int]] = field(
    default_factory=lambda: {
        "action": [4],
    }
)

state_encoder_hidden_dim: int = 256
latent_dim: int = 256
network_hidden_dims: int = 256

# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
    default_factory=lambda: {"action": "min_max"},
)

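To relate the two sets of bounds in the hunk above, the following small sketch converts between standard deviation and log standard deviation. It assumes, per the commit message, that std_min/std_max are the bounds being replaced and that log_std_min/log_std_max are their replacements; the variable names are illustrative only.

```python
import math

old_std_range = (0.005, 5.0)  # std_min, std_max from the diff
new_log_std_range = (-5, 2)   # log_std_min, log_std_max from the diff

# Express the new bounds as standard deviations.
new_std_range = tuple(math.exp(v) for v in new_log_std_range)

# Express the old bounds in log space for a like-for-like comparison.
old_log_std_range = tuple(math.log(v) for v in old_std_range)

print(new_std_range)      # roughly (0.0067, 7.39)
print(old_log_std_range)  # roughly (-5.30, 1.61)
```

Under this reading, the new defaults cover a similar range to the old ones, with a slightly higher floor and ceiling on the standard deviation.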