trying to get sac running

This commit is contained in:
KeWang1017
2024-12-26 23:38:46 +00:00
committed by AdilZouitine
parent 80b86e9bc3
commit a113daa81e
3 changed files with 149 additions and 40 deletions

View File

@@ -20,6 +20,24 @@ from dataclasses import dataclass, field
@dataclass
class SACConfig:
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
discount = 0.99
temperature_init = 1.0
num_critics = 2
@@ -29,6 +47,9 @@ class SACConfig:
temperature_lr = 3e-4
critic_target_update_weight = 0.005
utd_ratio = 2
state_encoder_hidden_dim = 256
latent_dim = 50
target_entropy = None
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,