forked from tangger/lerobot
- Refactor observation encoder in modeling_sac.py
- Add `torch.compile` to the actor and learner servers
- Organize imports in `train_sac.py`
- Optimize the parameters push by not sending the frozen pre-trained encoder

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
committed by AdilZouitine
parent faab32fe14
commit b29401e4e2
@@ -259,6 +259,9 @@ def learner_push_parameters(
    while True:
        with policy_lock:
            params_dict = policy.actor.state_dict()
            if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
                params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}

            params_dict = move_state_dict_to_device(params_dict, device="cpu")
        # Serialize
        buf = io.BytesIO()
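To see the parameter-push optimization above in isolation, here is a minimal, self-contained sketch. The helper name serialize_actor_params, the dummy policy objects, and the use of torch.save to fill the buffer are illustrative assumptions; only the filtering of "encoder."-prefixed keys and the move to CPU mirror the diff.

import io
from types import SimpleNamespace

import torch
from torch import nn


def move_state_dict_to_device(state_dict, device="cpu"):
    # Stand-in for the repository helper: copy every tensor to the target device.
    return {k: v.to(device) for k, v in state_dict.items()}


def serialize_actor_params(policy):
    # Hypothetical helper mirroring the body of learner_push_parameters above.
    params_dict = policy.actor.state_dict()

    # Drop the frozen pre-trained vision encoder so it is never pushed to the actor server.
    if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
        # .items() is required: iterating a dict directly yields only the keys.
        params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}

    params_dict = move_state_dict_to_device(params_dict, device="cpu")

    # Assumption: the BytesIO buffer is filled with torch.save before being sent.
    buf = io.BytesIO()
    torch.save(params_dict, buf)
    return buf.getvalue()


# Tiny dummy actor whose parameter names include an "encoder." prefix.
class DummyActor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 8)
        self.head = nn.Linear(8, 2)


policy = SimpleNamespace(
    actor=DummyActor(),
    config=SimpleNamespace(vision_encoder_name="dummy-encoder", freeze_vision_encoder=True),
)

payload = serialize_actor_params(policy)
restored = torch.load(io.BytesIO(payload))
print(sorted(restored))  # only "head.*" keys; "encoder.*" was filtered out

Skipping the frozen encoder keeps each periodic learner-to-actor push small, since the pre-trained vision backbone never changes during training.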
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
    )
    # compile policy
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
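On the torch.compile change: compiling a module returns a wrapper that is still an nn.Module, which is what the assert above verifies before the optimizers and scheduler are built from the policy. A minimal sketch with an illustrative stand-in (DummyPolicy is not the repository's class):

import torch
from torch import nn


class DummyPolicy(nn.Module):
    # Illustrative stand-in for the SAC policy built in train_sac.py.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))

    def forward(self, obs):
        return self.net(obs)


policy = DummyPolicy()

# torch.compile wraps the module but keeps the nn.Module interface,
# so .parameters(), .state_dict() and optimizer construction still work.
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)

optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)
out = policy(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 2])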