Cleanup policies imports

This commit is contained in:
Simon Alibert
2025-03-08 11:08:40 +01:00
parent 6909b62a21
commit 80387285bb
15 changed files with 53 additions and 48 deletions

View File

@@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import (
MultiLeRobotDataset,
)
from lerobot.common.datasets.transforms import ImageTransforms
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
IMAGENET_STATS = {

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from . import act, diffusion, pi0, tdmpc, vqbet
from .config import PreTrainedConfig
from .factory import make_policy
from .pretrained import PreTrainedPolicy
__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy"]
__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy", "PreTrainedConfig", "PreTrainedPolicy"]

View File

@@ -1,4 +1,4 @@
from .configuration_act import ACTConfig
from .modeling_act import ACT
from .modeling_act import ACTPolicy
__all__ = ["ACTConfig", "ACT"]
__all__ = ["ACTConfig", "ACTPolicy"]

View File

@@ -22,53 +22,59 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.types import FeatureType
from .config import PreTrainedConfig
from .pretrained import PreTrainedPolicy
def get_policy_class(name: str) -> PreTrainedPolicy:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy
elif name == "diffusion":
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy
elif name == "act":
from lerobot.common.policies.act.modeling_act import ACTPolicy
if name == "act":
from .act import ACTPolicy
return ACTPolicy
elif name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
elif name == "diffusion":
from .diffusion import DiffusionPolicy
return VQBeTPolicy
return DiffusionPolicy
elif name == "pi0":
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
from .pi0 import PI0Policy
return PI0Policy
elif name == "tdmpc":
from .tdmpc import TDMPCPolicy
return TDMPCPolicy
elif name == "vqbet":
from .vqbet import VQBeTPolicy
return VQBeTPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
return DiffusionConfig(**kwargs)
elif policy_type == "act":
if policy_type == "act":
from .act import ACTConfig
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "diffusion":
from .diffusion import DiffusionConfig
return DiffusionConfig(**kwargs)
elif policy_type == "pi0":
from .pi0 import PI0Config
return PI0Config(**kwargs)
elif policy_type == "tdmpc":
from .tdmpc import TDMPCConfig
return TDMPCConfig(**kwargs)
elif policy_type == "vqbet":
from .vqbet import VQBeTConfig
return VQBeTConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

View File

@@ -15,8 +15,7 @@
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies import PreTrainedConfig, make_policy
torch.backends.cudnn.benchmark = True

View File

@@ -19,8 +19,7 @@ from pathlib import Path
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies import PreTrainedConfig, make_policy
def display(tensor: torch.Tensor):

View File

@@ -26,9 +26,10 @@ from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.utils.hub import HubMixin
from .config import PreTrainedConfig
T = TypeVar("T", bound="PreTrainedPolicy")
DEFAULT_POLICY_CARD = """

View File

@@ -17,7 +17,7 @@ from pathlib import Path
import draccus
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies import PreTrainedConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs import parser

View File

@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common import envs, policies # noqa: F401
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies import PreTrainedConfig
from lerobot.configs import parser
from lerobot.configs.default import EvalConfig

View File

@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies import PreTrainedConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig

View File

@@ -141,7 +141,7 @@ from pprint import pformat
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
ControlPipelineConfig,

View File

@@ -67,8 +67,7 @@ from tqdm import trange
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies import PreTrainedPolicy, make_policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.random_utils import set_seed

View File

@@ -29,8 +29,7 @@ from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies import PreTrainedPolicy, make_policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.common.utils.random_utils import set_seed

View File

@@ -21,7 +21,8 @@ from safetensors.torch import save_file
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy, make_policy_config
from lerobot.common.policies import make_policy
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.utils.random_utils import set_seed
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig

View File

@@ -41,9 +41,8 @@ from unittest.mock import patch
import pytest
from lerobot.common.policies import PreTrainedConfig, make_policy
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
RecordControlConfig,