diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 38c01b42..4e1c3386 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import ( MultiLeRobotDataset, ) from lerobot.common.datasets.transforms import ImageTransforms -from lerobot.configs.policies import PreTrainedConfig +from lerobot.common.policies.config import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig IMAGENET_STATS = { diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 7a5819b7..bce1a52d 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -16,9 +16,10 @@ from dataclasses import dataclass, field from lerobot.common.optim.optimizers import AdamWConfig -from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode +from ..config import PreTrainedConfig + @PreTrainedConfig.register_subclass("act") @dataclass diff --git a/lerobot/configs/policies.py b/lerobot/common/policies/config.py similarity index 100% rename from lerobot/configs/policies.py rename to lerobot/common/policies/config.py diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index e73c65fe..d16bbddd 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -18,9 +18,10 @@ from dataclasses import dataclass, field from lerobot.common.optim.optimizers import AdamConfig from lerobot.common.optim.schedulers import DiffuserSchedulerConfig -from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode +from ..config import PreTrainedConfig + @PreTrainedConfig.register_subclass("diffusion") @dataclass diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 5d2f6cb5..8caac336 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -23,12 +23,12 @@ from lerobot.common.datasets.utils import dataset_to_policy_features from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig +from lerobot.common.policies.config import PreTrainedConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig -from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType diff --git a/lerobot/common/policies/pi0/configuration_pi0.py b/lerobot/common/policies/pi0/configuration_pi0.py index 8c7cc130..63ce3dc4 100644 --- a/lerobot/common/policies/pi0/configuration_pi0.py +++ b/lerobot/common/policies/pi0/configuration_pi0.py @@ -18,9 +18,10 @@ from lerobot.common.optim.optimizers import AdamWConfig from lerobot.common.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) -from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from ..config import PreTrainedConfig + @PreTrainedConfig.register_subclass("pi0") @dataclass diff --git a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py index cb3c0e9b..41cb938f 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py +++ b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py @@ -15,8 +15,8 @@ import torch from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.config import PreTrainedConfig from lerobot.common.policies.factory import make_policy -from lerobot.configs.policies import PreTrainedConfig torch.backends.cudnn.benchmark = True diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py index 6bd7c91f..b0df3084 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py @@ -19,8 +19,8 @@ from pathlib import Path import torch from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.common.policies.config import PreTrainedConfig from lerobot.common.policies.factory import make_policy -from lerobot.configs.policies import PreTrainedConfig def display(tensor: torch.Tensor): diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py index da4ef157..df6f491d 100644 --- a/lerobot/common/policies/pretrained.py +++ b/lerobot/common/policies/pretrained.py @@ -26,8 +26,8 @@ from safetensors.torch import load_model as load_model_as_safetensor from safetensors.torch import save_model as save_model_as_safetensor from torch import Tensor, nn +from lerobot.common.policies.config import PreTrainedConfig from lerobot.common.utils.hub import HubMixin -from lerobot.configs.policies import PreTrainedConfig T = TypeVar("T", bound="PreTrainedPolicy") diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index 3fce01df..df4ddd21 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -17,9 +17,10 @@ from dataclasses import dataclass, field from lerobot.common.optim.optimizers import AdamConfig -from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode +from ..config import PreTrainedConfig + @PreTrainedConfig.register_subclass("tdmpc") @dataclass diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py index 28e9c433..7d4fcace 100644 --- a/lerobot/common/policies/vqbet/configuration_vqbet.py +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -20,9 +20,10 @@ from dataclasses import dataclass, field from lerobot.common.optim.optimizers import AdamConfig from lerobot.common.optim.schedulers import VQBeTSchedulerConfig -from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode +from ..config import PreTrainedConfig + @PreTrainedConfig.register_subclass("vqbet") @dataclass diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py index 0ecd8683..c1e69b4a 100644 --- a/lerobot/common/robot_devices/control_configs.py +++ b/lerobot/common/robot_devices/control_configs.py @@ -17,9 +17,9 @@ from pathlib import Path import draccus +from lerobot.common.policies.config import PreTrainedConfig from lerobot.common.robot_devices.robots.configs import RobotConfig from lerobot.configs import parser -from lerobot.configs.policies import PreTrainedConfig @dataclass diff --git a/lerobot/configs/eval.py b/lerobot/configs/eval.py index 16b35291..fd6ebdb8 100644 --- a/lerobot/configs/eval.py +++ b/lerobot/configs/eval.py @@ -18,9 +18,9 @@ from dataclasses import dataclass, field from pathlib import Path from lerobot.common import envs, policies # noqa: F401 +from lerobot.common.policies.config import PreTrainedConfig from lerobot.configs import parser from lerobot.configs.default import EvalConfig -from lerobot.configs.policies import PreTrainedConfig @dataclass diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index 2b147a5b..cb189034 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -24,10 +24,10 @@ from huggingface_hub.errors import HfHubHTTPError from lerobot.common import envs from lerobot.common.optim import OptimizerConfig from lerobot.common.optim.schedulers import LRSchedulerConfig +from lerobot.common.policies.config import PreTrainedConfig from lerobot.common.utils.hub import HubMixin from lerobot.configs import parser from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig -from lerobot.configs.policies import PreTrainedConfig TRAIN_CONFIG_NAME = "train_config.json" diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 02041e30..32d8cc63 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -42,6 +42,7 @@ from unittest.mock import patch import pytest from lerobot.common.policies.act.configuration_act import ACTConfig +from lerobot.common.policies.config import PreTrainedConfig from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.control_configs import ( CalibrateControlConfig, @@ -49,7 +50,6 @@ from lerobot.common.robot_devices.control_configs import ( ReplayControlConfig, TeleoperateControlConfig, ) -from lerobot.configs.policies import PreTrainedConfig from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate from tests.test_robots import make_robot from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot