Move pretrained config
This commit is contained in:
@@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import (
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransforms
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
IMAGENET_STATS = {
|
||||
|
||||
@@ -16,9 +16,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("act")
|
||||
@dataclass
|
||||
|
||||
@@ -18,9 +18,10 @@ from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("diffusion")
|
||||
@dataclass
|
||||
|
||||
@@ -23,12 +23,12 @@ from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
|
||||
|
||||
@@ -18,9 +18,10 @@ from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
@@ -19,8 +19,8 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
|
||||
@@ -26,8 +26,8 @@ from safetensors.torch import load_model as load_model_as_safetensor
|
||||
from safetensors.torch import save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
|
||||
@@ -17,9 +17,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("tdmpc")
|
||||
@dataclass
|
||||
|
||||
@@ -20,9 +20,10 @@ from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vqbet")
|
||||
@dataclass
|
||||
|
||||
@@ -17,9 +17,9 @@ from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -18,9 +18,9 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common import envs, policies # noqa: F401
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import EvalConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -24,10 +24,10 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
from lerobot.common import envs
|
||||
from lerobot.common.optim import OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
@@ -49,7 +50,6 @@ from lerobot.common.robot_devices.control_configs import (
|
||||
ReplayControlConfig,
|
||||
TeleoperateControlConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from tests.test_robots import make_robot
|
||||
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
|
||||
Reference in New Issue
Block a user