- Refactor observation encoder in modeling_sac.py

- added `torch.compile` to the actor and learner servers.
- organized imports in `train_sac.py`
- optimized the parameters push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-01-31 16:45:52 +00:00
committed by AdilZouitine
parent faab32fe14
commit b29401e4e2
6 changed files with 199 additions and 85 deletions

View File

@@ -55,9 +55,10 @@ class SACConfig:
)
camera_number: int = 1
# Add type annotations for these fields:
vision_encoder_name: str = field(default="microsoft/resnet-18")
vision_encoder_name: str | None = field(default="microsoft/resnet-18")
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = False
shared_encoder: bool = True
discount: float = 0.99
temperature_init: float = 1.0
num_critics: int = 2