Added functions for converting the replay buffer from and to LeRobotDataset. When we want to save the replay buffer, we convert it first to LeRobotDataset format and save it locally and vice-versa. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
123 lines
5.4 KiB
Python
123 lines
5.4 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import inspect
|
|
import logging
|
|
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
|
from lerobot.common.policies.policy_protocol import Policy
|
|
from lerobot.common.utils.utils import get_safe_torch_device
|
|
|
|
|
|
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
    """Instantiate `policy_cfg_class` from the `policy` section of a Hydra config.

    Only the entries of `hydra_cfg.policy` whose key matches a parameter in the signature of
    `policy_cfg_class` are forwarded; any other entry is dropped. If some signature parameters have no
    matching key in `hydra_cfg.policy`, a warning listing them is logged and construction proceeds anyway.

    Args:
        policy_cfg_class: Callable (typically a config dataclass) whose signature defines the accepted
            keyword arguments.
        hydra_cfg: Hydra/OmegaConf configuration exposing a `policy` node.

    Returns:
        The object returned by calling `policy_cfg_class` with the filtered keyword arguments.
    """
    accepted_params = set(inspect.signature(policy_cfg_class).parameters)

    missing_params = accepted_params.difference(hydra_cfg.policy)
    if missing_params:
        logging.warning(f"Hydra config is missing arguments: {missing_params}")

    # Interpolations are resolved here so that the config class receives plain Python values.
    resolved_policy = OmegaConf.to_container(hydra_cfg.policy, resolve=True)

    init_kwargs = {}
    for key, value in resolved_policy.items():
        if key not in accepted_params:
            continue
        # OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples
        # to avoid issues with mutable defaults, so top-level lists are converted to tuples.
        init_kwargs[key] = tuple(value) if isinstance(value, list) else value

    return policy_cfg_class(**init_kwargs)
|
|
|
|
|
|
def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
|
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
|
if name == "tdmpc":
|
|
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
|
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
|
|
|
return TDMPCPolicy, TDMPCConfig
|
|
elif name == "diffusion":
|
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
|
|
|
return DiffusionPolicy, DiffusionConfig
|
|
elif name == "act":
|
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
|
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
|
|
|
return ACTPolicy, ACTConfig
|
|
elif name == "vqbet":
|
|
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
|
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
|
|
|
return VQBeTPolicy, VQBeTConfig
|
|
elif name == "sac":
|
|
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
|
|
|
return SACPolicy, SACConfig
|
|
else:
|
|
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
|
|
|
|
|
def make_policy(
    hydra_cfg: DictConfig,
    pretrained_policy_name_or_path: str | None = None,
    dataset_stats=None,
    *args,
    **kwargs,
) -> Policy:
    """Make an instance of a policy class and move it to the configured device.

    The policy class and its config class are selected from `hydra_cfg.policy.name`, and the policy
    config is always built from `hydra_cfg.policy`, whether or not pretrained weights are loaded.

    Args:
        hydra_cfg: A parsed Hydra configuration (see scripts). `hydra_cfg.policy` provides the policy name
            and configuration values, and `hydra_cfg.device` selects the device the policy is moved to.
        pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
            directory containing weights saved using `Policy.save_pretrained`. When given, a policy is
            first built from the Hydra-derived config and the pretrained weights are then loaded into it.
            When `None`, a fresh policy is created.
        dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy.
            It is forwarded to the policy constructor only when creating a fresh policy and is not used
            when `pretrained_policy_name_or_path` is given. Mutual exclusivity with
            `pretrained_policy_name_or_path` is not enforced (the corresponding check is disabled).
        *args: Accepted for call-site compatibility; currently not forwarded to the policy constructor.
        **kwargs: Accepted for call-site compatibility; currently not forwarded to the policy constructor.

    Returns:
        The policy instance, after `policy.to(...)` has been called with the device obtained from
        `get_safe_torch_device(hydra_cfg.device)`.
    """
    policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
    policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)

    if pretrained_policy_name_or_path is not None:
        # Load a pretrained policy and override the config if needed (for example, if there are
        # inference-time hyperparameters that we want to vary).
        # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with
        # pretrained weights, which are then loaded into a fresh policy with the desired config. This PR
        # in huggingface_hub should make it possible to avoid the hack:
        # https://github.com/huggingface/huggingface_hub/pull/2274.
        policy = policy_cls(policy_cfg)
        pretrained_state = policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()
        policy.load_state_dict(pretrained_state)
    else:
        # Make a fresh policy.
        # NOTE: `*args` and `**kwargs` are part of this function's signature (e.g. to let callers supply
        # a device for the sac policy) but they are not passed on to the constructor here.
        policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)

    target_device = get_safe_torch_device(hydra_cfg.device)
    policy.to(target_device)

    return policy
|