- Refactored the observation encoder in `modeling_sac.py`.
- Added `torch.compile` to the actor and learner servers.
- Organized imports in `train_sac.py`.
- Optimized the parameter push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
```diff
@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
     # pretrained_policy_name_or_path=None,
     # device=device,
     # )
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     # HACK for maniskill
```
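As context for the added line: `torch.compile` returns an `OptimizedModule` that wraps the original policy and still subclasses `nn.Module`, which is what the `assert isinstance(policy, nn.Module)` verifies. A minimal sketch with a stand-in module (the real object here is a `SACPolicy`):

```python
import torch
from torch import nn


class TinyPolicy(nn.Module):
    """Stand-in for SACPolicy; any nn.Module compiles the same way."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Linear(8, 2)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


policy = TinyPolicy()
policy = torch.compile(policy)  # wraps the module in an OptimizedModule
assert isinstance(policy, nn.Module)  # the wrapper is still an nn.Module

action = policy(torch.randn(1, 8))  # first call triggers compilation
```

Attribute access such as `policy.actor` later in the file keeps working because `OptimizedModule` forwards attribute lookups to the wrapped module.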
```diff
@@ -237,7 +238,9 @@ def act_with_policy(cfg: DictConfig):
                 logging.debug("[ACTOR] Load new parameters from Learner.")
                 state_dict = parameters_queue.get()
                 state_dict = move_state_dict_to_device(state_dict, device=device)
-                policy.actor.load_state_dict(state_dict)
+                # strict=False for the case when the image encoder is frozen and not sent
+                # through the network. Be careful: this might cause issues if the wrong keys are passed.
+                policy.actor.load_state_dict(state_dict, strict=False)
 
             if len(list_transition_to_send_to_learner) > 0:
                 logging.debug(
```
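A minimal sketch of why `strict=False` is needed, using a hypothetical two-submodule actor: once the frozen encoder's entries are filtered out of the payload, `strict=True` would raise a missing-keys `RuntimeError`, while `strict=False` keeps the actor's local encoder weights and reports what was skipped:

```python
import torch
from torch import nn


class Actor(nn.Module):
    """Hypothetical actor: a frozen encoder plus a trainable head."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 4)  # frozen, never sent over the network
        self.head = nn.Linear(4, 2)     # trainable, pushed by the learner


actor = Actor()

# Simulate a payload from the learner with the encoder keys stripped out.
payload = {k: v for k, v in actor.state_dict().items() if not k.startswith("encoder.")}

result = actor.load_state_dict(payload, strict=False)
print(result.missing_keys)     # ['encoder.weight', 'encoder.bias'] -- kept as-is locally
print(result.unexpected_keys)  # [] -- wrong key names would land here, silently ignored
```

As the diff's comment warns, `strict=False` also swallows mismatched names, so inspecting `missing_keys`/`unexpected_keys` is a cheap safeguard against loading the wrong state dict.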
```diff
@@ -259,6 +259,9 @@ def learner_push_parameters(
     while True:
         with policy_lock:
             params_dict = policy.actor.state_dict()
+        if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+            params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
+
         params_dict = move_state_dict_to_device(params_dict, device="cpu")
         # Serialize
         buf = io.BytesIO()
```
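The learner-side push can be sketched the same way: drop the frozen encoder's entries from the state dict, move the rest to CPU, and serialize into an in-memory buffer. The `policy.config` flags and the `move_state_dict_to_device` helper from the diff are assumptions here; plain dict comprehensions stand in for them:

```python
import io

import torch
from torch import nn

# Stand-in for policy.actor: submodule names mirror the "encoder." prefix check.
actor = nn.Sequential()
actor.add_module("encoder", nn.Linear(4, 4))
actor.add_module("head", nn.Linear(4, 2))

params_dict = actor.state_dict()

# Skip the frozen, pre-trained encoder: the actor already holds those weights,
# so re-sending them on every push only wastes bandwidth.
# .items() is required -- iterating a dict directly yields only its keys.
params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}

# Move to CPU and serialize, mirroring the io.BytesIO() buffer in the diff.
params_dict = {k: v.to("cpu") for k, v in params_dict.items()}
buf = io.BytesIO()
torch.save(params_dict, buf)
payload = buf.getvalue()  # bytes ready to send to the actor
```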
```diff
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         dataset_stats=None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
+    # compile policy
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
```
```diff
@@ -13,26 +13,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import functools
-from pprint import pformat
+import logging
 import random
-from typing import Optional, Sequence, TypedDict, Callable
+from pprint import pformat
+from typing import Callable, Optional, Sequence, TypedDict
 
 import hydra
 import torch
 import torch.nn.functional as F
-from torch import nn
-from tqdm import tqdm
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
-
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from torch import nn
+from tqdm import tqdm
 
 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.envs.factory import make_env, make_maniskill_env
-from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
+from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
```