From 80387285bb23b1565854e8bfe6e5f363d106e8a2 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sat, 8 Mar 2025 11:08:40 +0100 Subject: [PATCH] Cleanup policies imports --- lerobot/common/datasets/factory.py | 2 +- lerobot/common/policies/__init__.py | 4 +- lerobot/common/policies/act/__init__.py | 4 +- lerobot/common/policies/factory.py | 62 ++++++++++--------- .../pi0/conversion_scripts/benchmark.py | 3 +- .../conversion_scripts/compare_with_jax.py | 3 +- lerobot/common/policies/pretrained.py | 3 +- .../common/robot_devices/control_configs.py | 2 +- lerobot/configs/eval.py | 2 +- lerobot/configs/train.py | 2 +- lerobot/scripts/control_robot.py | 2 +- lerobot/scripts/eval.py | 3 +- lerobot/scripts/train.py | 3 +- tests/scripts/save_policy_to_safetensors.py | 3 +- tests/test_control_robot.py | 3 +- 15 files changed, 53 insertions(+), 48 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4e1c33863..ec8e4b936 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import ( MultiLeRobotDataset, ) from lerobot.common.datasets.transforms import ImageTransforms -from lerobot.common.policies.config import PreTrainedConfig +from lerobot.common.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig IMAGENET_STATS = { diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index 8b3c25194..45b78d9ec 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -13,6 +13,8 @@ # limitations under the License. from . import act, diffusion, pi0, tdmpc, vqbet +from .config import PreTrainedConfig from .factory import make_policy +from .pretrained import PreTrainedPolicy -__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy"] +__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy", "PreTrainedConfig", "PreTrainedPolicy"] diff --git a/lerobot/common/policies/act/__init__.py b/lerobot/common/policies/act/__init__.py index 8b3b2ba2d..72ca98306 100644 --- a/lerobot/common/policies/act/__init__.py +++ b/lerobot/common/policies/act/__init__.py @@ -1,4 +1,4 @@ from .configuration_act import ACTConfig -from .modeling_act import ACT +from .modeling_act import ACTPolicy -__all__ = ["ACTConfig", "ACT"] +__all__ = ["ACTConfig", "ACTPolicy"] diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 8caac336b..05a8026e3 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -22,53 +22,59 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.common.datasets.utils import dataset_to_policy_features from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.utils import env_to_policy_features -from lerobot.common.policies.act.configuration_act import ACTConfig -from lerobot.common.policies.config import PreTrainedConfig -from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig -from lerobot.common.policies.pi0.configuration_pi0 import PI0Config -from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig -from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.configs.types import FeatureType +from .config import PreTrainedConfig +from .pretrained import PreTrainedPolicy + def get_policy_class(name: str) -> PreTrainedPolicy: """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" - if name == "tdmpc": - from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy - - return TDMPCPolicy - elif name == "diffusion": - from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy - - return DiffusionPolicy - elif name == "act": - from lerobot.common.policies.act.modeling_act import ACTPolicy + if name == "act": + from .act import ACTPolicy return ACTPolicy - elif name == "vqbet": - from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy + elif name == "diffusion": + from .diffusion import DiffusionPolicy - return VQBeTPolicy + return DiffusionPolicy elif name == "pi0": - from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy + from .pi0 import PI0Policy return PI0Policy + elif name == "tdmpc": + from .tdmpc import TDMPCPolicy + + return TDMPCPolicy + elif name == "vqbet": + from .vqbet import VQBeTPolicy + + return VQBeTPolicy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: - if policy_type == "tdmpc": - return TDMPCConfig(**kwargs) - elif policy_type == "diffusion": - return DiffusionConfig(**kwargs) - elif policy_type == "act": + if policy_type == "act": + from .act import ACTConfig + return ACTConfig(**kwargs) - elif policy_type == "vqbet": - return VQBeTConfig(**kwargs) + elif policy_type == "diffusion": + from .diffusion import DiffusionConfig + + return DiffusionConfig(**kwargs) elif policy_type == "pi0": + from .pi0 import PI0Config + return PI0Config(**kwargs) + elif policy_type == "tdmpc": + from .tdmpc import TDMPCConfig + + return TDMPCConfig(**kwargs) + elif policy_type == "vqbet": + from .vqbet import VQBeTConfig + + return VQBeTConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") diff --git a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py index 41cb938ff..625dde910 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py +++ b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py @@ -15,8 +15,7 @@ import torch from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.policies.config import PreTrainedConfig -from lerobot.common.policies.factory import make_policy +from lerobot.common.policies import PreTrainedConfig, make_policy torch.backends.cudnn.benchmark = True diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py index b0df3084d..a0bb8526e 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py @@ -19,8 +19,7 @@ from pathlib import Path import torch from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata -from lerobot.common.policies.config import PreTrainedConfig -from lerobot.common.policies.factory import make_policy +from lerobot.common.policies import PreTrainedConfig, make_policy def display(tensor: torch.Tensor): diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py index df6f491d5..5ddf02c50 100644 --- a/lerobot/common/policies/pretrained.py +++ b/lerobot/common/policies/pretrained.py @@ -26,9 +26,10 @@ from safetensors.torch import load_model as load_model_as_safetensor from safetensors.torch import save_model as save_model_as_safetensor from torch import Tensor, nn -from lerobot.common.policies.config import PreTrainedConfig from lerobot.common.utils.hub import HubMixin +from .config import PreTrainedConfig + T = TypeVar("T", bound="PreTrainedPolicy") DEFAULT_POLICY_CARD = """ diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py index c1e69b4a3..7bfd166de 100644 --- a/lerobot/common/robot_devices/control_configs.py +++ b/lerobot/common/robot_devices/control_configs.py @@ -17,7 +17,7 @@ from pathlib import Path import draccus -from lerobot.common.policies.config import PreTrainedConfig +from lerobot.common.policies import PreTrainedConfig from lerobot.common.robot_devices.robots.configs import RobotConfig from lerobot.configs import parser diff --git a/lerobot/configs/eval.py b/lerobot/configs/eval.py index fd6ebdb8a..b13de8fd1 100644 --- a/lerobot/configs/eval.py +++ b/lerobot/configs/eval.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, field from pathlib import Path from lerobot.common import envs, policies # noqa: F401 -from lerobot.common.policies.config import PreTrainedConfig +from lerobot.common.policies import PreTrainedConfig from lerobot.configs import parser from lerobot.configs.default import EvalConfig diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index cb1890348..c9ce60164 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError from lerobot.common import envs from lerobot.common.optim import OptimizerConfig from lerobot.common.optim.schedulers import LRSchedulerConfig -from lerobot.common.policies.config import PreTrainedConfig +from lerobot.common.policies import PreTrainedConfig from lerobot.common.utils.hub import HubMixin from lerobot.configs import parser from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 3c3c43f91..b94f5b00a 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -141,7 +141,7 @@ from pprint import pformat # from safetensors.torch import load_file, save_file from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.policies.factory import make_policy +from lerobot.common.policies import make_policy from lerobot.common.robot_devices.control_configs import ( CalibrateControlConfig, ControlPipelineConfig, diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index d7a4201f2..c54bccf77 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -67,8 +67,7 @@ from tqdm import trange from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import preprocess_observation -from lerobot.common.policies.factory import make_policy -from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies import PreTrainedPolicy, make_policy from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.utils.io_utils import write_video from lerobot.common.utils.random_utils import set_seed diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index f2b1e29e3..5559e7b39 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -29,8 +29,7 @@ from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.optim.factory import make_optimizer_and_scheduler -from lerobot.common.policies.factory import make_policy -from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies import PreTrainedPolicy, make_policy from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.common.utils.random_utils import set_seed diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index 60fd9fc05..4f61bbd4e 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -21,7 +21,8 @@ from safetensors.torch import save_file from lerobot.common.datasets.factory import make_dataset from lerobot.common.optim.factory import make_optimizer_and_scheduler -from lerobot.common.policies.factory import make_policy, make_policy_config +from lerobot.common.policies import make_policy +from lerobot.common.policies.factory import make_policy_config from lerobot.common.utils.random_utils import set_seed from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 32d8cc638..77b73f32a 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -41,9 +41,8 @@ from unittest.mock import patch import pytest +from lerobot.common.policies import PreTrainedConfig, make_policy from lerobot.common.policies.act.configuration_act import ACTConfig -from lerobot.common.policies.config import PreTrainedConfig -from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.control_configs import ( CalibrateControlConfig, RecordControlConfig,