- Refactor observation encoder in modeling_sac.py

- added `torch.compile` to the actor and learner servers.
- organized imports in `train_sac.py`
- optimized the parameters push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-01-31 16:45:52 +00:00
committed by AdilZouitine
parent faab32fe14
commit b29401e4e2
6 changed files with 199 additions and 85 deletions

View File

@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
# pretrained_policy_name_or_path=None,
# device=device,
# )
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
# HACK for maniskill
@@ -237,7 +238,9 @@ def act_with_policy(cfg: DictConfig):
logging.debug("[ACTOR] Load new parameters from Learner.")
state_dict = parameters_queue.get()
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.actor.load_state_dict(state_dict)
# strict=False for the case when the image encoder is frozen and not sent through
# the network. Becareful might cause issues if the wrong keys are passed
policy.actor.load_state_dict(state_dict, strict=False)
if len(list_transition_to_send_to_learner) > 0:
logging.debug(