Move pretrained config

This commit is contained in:
Simon Alibert
2025-03-08 10:32:21 +01:00
parent c91a53be11
commit 6909b62a21
15 changed files with 19 additions and 14 deletions

View File

@@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import (
MultiLeRobotDataset,
)
from lerobot.common.datasets.transforms import ImageTransforms
from lerobot.configs.policies import PreTrainedConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
IMAGENET_STATS = {

View File

@@ -16,9 +16,10 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("act")
@dataclass

View File

@@ -18,9 +18,10 @@ from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("diffusion")
@dataclass

View File

@@ -23,12 +23,12 @@ from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType

View File

@@ -18,9 +18,10 @@ from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("pi0")
@dataclass

View File

@@ -15,8 +15,8 @@
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
torch.backends.cudnn.benchmark = True

View File

@@ -19,8 +19,8 @@ from pathlib import Path
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
def display(tensor: torch.Tensor):

View File

@@ -26,8 +26,8 @@ from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.configs.policies import PreTrainedConfig
T = TypeVar("T", bound="PreTrainedPolicy")

View File

@@ -17,9 +17,10 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("tdmpc")
@dataclass

View File

@@ -20,9 +20,10 @@ from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("vqbet")
@dataclass

View File

@@ -17,9 +17,9 @@ from pathlib import Path
import draccus
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
@dataclass

View File

@@ -18,9 +18,9 @@ from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common import envs, policies # noqa: F401
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
@dataclass

View File

@@ -24,10 +24,10 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
TRAIN_CONFIG_NAME = "train_config.json"

View File

@@ -42,6 +42,7 @@ from unittest.mock import patch
import pytest
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.config import PreTrainedConfig
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
@@ -49,7 +50,6 @@ from lerobot.common.robot_devices.control_configs import (
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot