- Refactor the observation encoder in `modeling_sac.py`.

- Add `torch.compile` to the actor and learner servers.
- Organize imports in `train_sac.py`.
- Optimize the parameter push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi, 2025-01-31 16:45:52 +00:00
parent f1c8bfe01e
commit 506821c7df
6 changed files with 199 additions and 85 deletions


@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
 # pretrained_policy_name_or_path=None,
 # device=device,
 # )
+policy = torch.compile(policy)
 assert isinstance(policy, nn.Module)
 # HACK for maniskill
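
For context: `torch.compile` wraps the policy in an `OptimizedModule` that still subclasses `nn.Module`, which is why the existing assert keeps passing after the call. A minimal sketch with a stand-in module (not the actual SAC policy):

```python
import torch
from torch import nn

policy = nn.Linear(8, 2)  # stand-in for the SAC policy
policy = torch.compile(policy)

# The wrapper is an OptimizedModule, which subclasses nn.Module,
# so the assert in act_with_policy keeps passing.
assert isinstance(policy, nn.Module)
print(type(policy).__name__)  # OptimizedModule
```
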
@@ -237,7 +238,9 @@ def act_with_policy(cfg: DictConfig):
 logging.debug("[ACTOR] Load new parameters from Learner.")
 state_dict = parameters_queue.get()
 state_dict = move_state_dict_to_device(state_dict, device=device)
-policy.actor.load_state_dict(state_dict)
+# strict=False for the case when the image encoder is frozen and not sent through
+# the network. Be careful: this might cause issues if the wrong keys are passed.
+policy.actor.load_state_dict(state_dict, strict=False)
 if len(list_transition_to_send_to_learner) > 0:
     logging.debug(
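
The `strict=False` flag is what makes the slimmer parameter push workable on the actor side: missing encoder keys are reported rather than raised, and the actor keeps its local pre-trained encoder weights. A minimal sketch of that behavior, assuming a toy module layout (the `encoder`/`head` names are illustrative, not the real actor's):

```python
from torch import nn

actor = nn.ModuleDict({
    "encoder": nn.Linear(4, 8),  # frozen, pre-trained; not sent by the learner
    "head": nn.Linear(8, 2),     # trainable; pushed over the network
})

# A partial state dict, as received from the learner: encoder keys absent.
pushed = {k: v for k, v in actor.state_dict().items() if k.startswith("head.")}

# strict=False tolerates the missing encoder keys: the local weights are
# kept and the mismatch is reported instead of raising a RuntimeError.
missing, unexpected = actor.load_state_dict(pushed, strict=False)
print(missing)     # ['encoder.weight', 'encoder.bias']
print(unexpected)  # []
```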


@@ -259,6 +259,9 @@ def learner_push_parameters(
 while True:
     with policy_lock:
         params_dict = policy.actor.state_dict()
+        if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+            params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
     params_dict = move_state_dict_to_device(params_dict, device="cpu")
     # Serialize
     buf = io.BytesIO()
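
On the learner side, the filtered state dict is what gets serialized into the buffer, so the frozen encoder weights never cross the network. A rough self-contained sketch of the save path (toy module and sizes, not the real actor):

```python
import io
import torch
from torch import nn

actor = nn.ModuleDict({"encoder": nn.Linear(4, 8), "head": nn.Linear(8, 2)})

# Drop the frozen encoder weights; only the trainable head is serialized.
params_dict = {k: v for k, v in actor.state_dict().items() if not k.startswith("encoder.")}
params_dict = {k: v.to("cpu") for k, v in params_dict.items()}

buf = io.BytesIO()
torch.save(params_dict, buf)
print(sorted(params_dict))  # ['head.bias', 'head.weight']
print(buf.getbuffer().nbytes, "bytes pushed")
```
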
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     dataset_stats=None,
     pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
 )
+# compile policy
+policy = torch.compile(policy)
 assert isinstance(policy, nn.Module)
 optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
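
Compiling before `make_optimizers_and_scheduler` should be safe because `torch.compile` wraps the module without copying its parameters, so optimizers built afterwards update the same tensors. A quick sanity check of that assumption with a stand-in module:

```python
import torch
from torch import nn

policy = nn.Linear(8, 2)  # stand-in for the SAC policy
before = {id(p) for p in policy.parameters()}

policy = torch.compile(policy)

# The compiled wrapper exposes the original module's parameters,
# so optimizers created after compilation act on the same tensors.
after = {id(p) for p in policy.parameters()}
assert before == after

optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)
```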


@@ -13,26 +13,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-from pprint import pformat
+import functools
+import logging
 import random
-from typing import Optional, Sequence, TypedDict, Callable
+from pprint import pformat
+from typing import Callable, Optional, Sequence, TypedDict
 import hydra
 import torch
 import torch.nn.functional as F
-from torch import nn
-from tqdm import tqdm
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from torch import nn
+from tqdm import tqdm
 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.envs.factory import make_env, make_maniskill_env
-from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
+from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy