Move pretrained config

This commit is contained in:
Simon Alibert
2025-03-08 10:32:21 +01:00
parent c91a53be11
commit 6909b62a21
15 changed files with 19 additions and 14 deletions

View File

@@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import (
MultiLeRobotDataset, MultiLeRobotDataset,
) )
from lerobot.common.datasets.transforms import ImageTransforms from lerobot.common.datasets.transforms import ImageTransforms
from lerobot.configs.policies import PreTrainedConfig from lerobot.common.policies.config import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
IMAGENET_STATS = { IMAGENET_STATS = {

View File

@@ -16,9 +16,10 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("act") @PreTrainedConfig.register_subclass("act")
@dataclass @dataclass

View File

@@ -18,9 +18,10 @@ from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("diffusion") @PreTrainedConfig.register_subclass("diffusion")
@dataclass @dataclass

View File

@@ -23,12 +23,12 @@ from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType from lerobot.configs.types import FeatureType

View File

@@ -18,9 +18,10 @@ from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import ( from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig, CosineDecayWithWarmupSchedulerConfig,
) )
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("pi0") @PreTrainedConfig.register_subclass("pi0")
@dataclass @dataclass

View File

@@ -15,8 +15,8 @@
import torch import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True

View File

@@ -19,8 +19,8 @@ from pathlib import Path
import torch import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
def display(tensor: torch.Tensor): def display(tensor: torch.Tensor):

View File

@@ -26,8 +26,8 @@ from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn from torch import Tensor, nn
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.utils.hub import HubMixin from lerobot.common.utils.hub import HubMixin
from lerobot.configs.policies import PreTrainedConfig
T = TypeVar("T", bound="PreTrainedPolicy") T = TypeVar("T", bound="PreTrainedPolicy")

View File

@@ -17,9 +17,10 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig from lerobot.common.optim.optimizers import AdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("tdmpc") @PreTrainedConfig.register_subclass("tdmpc")
@dataclass @dataclass

View File

@@ -20,9 +20,10 @@ from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("vqbet") @PreTrainedConfig.register_subclass("vqbet")
@dataclass @dataclass

View File

@@ -17,9 +17,9 @@ from pathlib import Path
import draccus import draccus
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
@dataclass @dataclass

View File

@@ -18,9 +18,9 @@ from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from lerobot.common import envs, policies # noqa: F401 from lerobot.common import envs, policies # noqa: F401
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.default import EvalConfig from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
@dataclass @dataclass

View File

@@ -24,10 +24,10 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot.common import envs from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.utils.hub import HubMixin from lerobot.common.utils.hub import HubMixin
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
TRAIN_CONFIG_NAME = "train_config.json" TRAIN_CONFIG_NAME = "train_config.json"

View File

@@ -42,6 +42,7 @@ from unittest.mock import patch
import pytest import pytest
from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import ( from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig, CalibrateControlConfig,
@@ -49,7 +50,6 @@ from lerobot.common.robot_devices.control_configs import (
ReplayControlConfig, ReplayControlConfig,
TeleoperateControlConfig, TeleoperateControlConfig,
) )
from lerobot.configs.policies import PreTrainedConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot from tests.test_robots import make_robot
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot