- Refactor observation encoder in modeling_sac.py
- added `torch.compile` to the actor and learner servers.
- organized imports in `train_sac.py`
- optimized the parameters push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
committed by
AdilZouitine
parent
faab32fe14
commit
b29401e4e2
@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
|
||||
# pretrained_policy_name_or_path=None,
|
||||
# device=device,
|
||||
# )
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# HACK for maniskill
|
||||
@@ -237,7 +238,9 @@ def act_with_policy(cfg: DictConfig):
|
||||
logging.debug("[ACTOR] Load new parameters from Learner.")
|
||||
state_dict = parameters_queue.get()
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.actor.load_state_dict(state_dict)
|
||||
# strict=False for the case when the image encoder is frozen and not sent through
|
||||
# the network. Be careful: this might cause issues if the wrong keys are passed.
|
||||
policy.actor.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
logging.debug(
|
||||
|
||||
Reference in New Issue
Block a user