- Refactor observation encoder in `modeling_sac.py`

- Added `torch.compile` to the actor and learner servers.
- Organized imports in `train_sac.py`.
- Optimized the parameter push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Author: Michel Aractingi
Date: 2025-01-31 16:45:52 +00:00
Committed by: AdilZouitine
Parent: faab32fe14
Commit: b29401e4e2
6 changed files with 199 additions and 85 deletions


@@ -259,6 +259,9 @@ def learner_push_parameters(
    while True:
        with policy_lock:
            params_dict = policy.actor.state_dict()
        if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
            params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
        params_dict = move_state_dict_to_device(params_dict, device="cpu")
        # Serialize
        buf = io.BytesIO()
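
For context, the sketch below (not the repository's code) illustrates the technique this hunk applies: dropping frozen encoder entries from a `state_dict` before serializing it, so the pushed payload only carries trainable weights. The `SmallActor` module and the `encoder.` key prefix are illustrative assumptions.

```python
# Minimal sketch, assuming an actor whose frozen sub-module is registered
# under the attribute name "encoder" (hypothetical names for illustration).
import io

import torch
from torch import nn


class SmallActor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)        # stands in for a frozen pre-trained vision encoder
        self.encoder.requires_grad_(False)     # frozen: its weights never change during training
        self.head = nn.Linear(16, 4)           # trainable part that actually needs to be pushed

    def forward(self, x):
        return self.head(self.encoder(x))


actor = SmallActor()
params_dict = actor.state_dict()

# Drop the frozen encoder weights before serializing; since they never change,
# the receiving side does not need a fresh copy of them.
params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
params_dict = {k: v.to("cpu") for k, v in params_dict.items()}

# Serialize the reduced state_dict, as the server would before sending it.
buf = io.BytesIO()
torch.save(params_dict, buf)
print(f"payload bytes: {buf.getbuffer().nbytes}, keys: {list(params_dict)}")
```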
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
    )
    # compile policy
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)
    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
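
As a quick sanity check on the `torch.compile` change, here is a minimal sketch (assuming PyTorch 2.x, independent of the training script): a compiled module is still an `nn.Module`, so the `isinstance` assertion and the optimizer construction that follow it continue to work unchanged.

```python
# Minimal sketch: torch.compile wraps an nn.Module in an OptimizedModule
# that still subclasses nn.Module, so downstream code needs no changes.
import torch
from torch import nn

policy = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))
policy = torch.compile(policy)

assert isinstance(policy, nn.Module)  # passes: the compiled wrapper is an nn.Module

optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)
out = policy(torch.randn(4, 8))  # first call triggers compilation
print(out.shape)
```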