Move pretrained config
This commit is contained in:
@@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import (
|
|||||||
MultiLeRobotDataset,
|
MultiLeRobotDataset,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.transforms import ImageTransforms
|
from lerobot.common.datasets.transforms import ImageTransforms
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.common.policies.config import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
|
||||||
IMAGENET_STATS = {
|
IMAGENET_STATS = {
|
||||||
|
|||||||
@@ -16,9 +16,10 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.common.optim.optimizers import AdamWConfig
|
from lerobot.common.optim.optimizers import AdamWConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
from ..config import PreTrainedConfig
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("act")
|
@PreTrainedConfig.register_subclass("act")
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -18,9 +18,10 @@ from dataclasses import dataclass, field
|
|||||||
|
|
||||||
from lerobot.common.optim.optimizers import AdamConfig
|
from lerobot.common.optim.optimizers import AdamConfig
|
||||||
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
|
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
from ..config import PreTrainedConfig
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("diffusion")
|
@PreTrainedConfig.register_subclass("diffusion")
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -23,12 +23,12 @@ from lerobot.common.datasets.utils import dataset_to_policy_features
|
|||||||
from lerobot.common.envs.configs import EnvConfig
|
from lerobot.common.envs.configs import EnvConfig
|
||||||
from lerobot.common.envs.utils import env_to_policy_features
|
from lerobot.common.envs.utils import env_to_policy_features
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
|
from lerobot.common.policies.config import PreTrainedConfig
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,9 +18,10 @@ from lerobot.common.optim.optimizers import AdamWConfig
|
|||||||
from lerobot.common.optim.schedulers import (
|
from lerobot.common.optim.schedulers import (
|
||||||
CosineDecayWithWarmupSchedulerConfig,
|
CosineDecayWithWarmupSchedulerConfig,
|
||||||
)
|
)
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
|
||||||
|
from ..config import PreTrainedConfig
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("pi0")
|
@PreTrainedConfig.register_subclass("pi0")
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -15,8 +15,8 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.policies.config import PreTrainedConfig
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
|
from lerobot.common.policies.config import PreTrainedConfig
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
def display(tensor: torch.Tensor):
|
def display(tensor: torch.Tensor):
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ from safetensors.torch import load_model as load_model_as_safetensor
|
|||||||
from safetensors.torch import save_model as save_model_as_safetensor
|
from safetensors.torch import save_model as save_model_as_safetensor
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from lerobot.common.policies.config import PreTrainedConfig
|
||||||
from lerobot.common.utils.hub import HubMixin
|
from lerobot.common.utils.hub import HubMixin
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||||
|
|
||||||
|
|||||||
@@ -17,9 +17,10 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.common.optim.optimizers import AdamConfig
|
from lerobot.common.optim.optimizers import AdamConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
from ..config import PreTrainedConfig
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("tdmpc")
|
@PreTrainedConfig.register_subclass("tdmpc")
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -20,9 +20,10 @@ from dataclasses import dataclass, field
|
|||||||
|
|
||||||
from lerobot.common.optim.optimizers import AdamConfig
|
from lerobot.common.optim.optimizers import AdamConfig
|
||||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
from ..config import PreTrainedConfig
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("vqbet")
|
@PreTrainedConfig.register_subclass("vqbet")
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ from pathlib import Path
|
|||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
|
from lerobot.common.policies.config import PreTrainedConfig
|
||||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -18,9 +18,9 @@ from dataclasses import dataclass, field
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lerobot.common import envs, policies # noqa: F401
|
from lerobot.common import envs, policies # noqa: F401
|
||||||
|
from lerobot.common.policies.config import PreTrainedConfig
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import EvalConfig
|
from lerobot.configs.default import EvalConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -24,10 +24,10 @@ from huggingface_hub.errors import HfHubHTTPError
|
|||||||
from lerobot.common import envs
|
from lerobot.common import envs
|
||||||
from lerobot.common.optim import OptimizerConfig
|
from lerobot.common.optim import OptimizerConfig
|
||||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||||
|
from lerobot.common.policies.config import PreTrainedConfig
|
||||||
from lerobot.common.utils.hub import HubMixin
|
from lerobot.common.utils.hub import HubMixin
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
|
|
||||||
TRAIN_CONFIG_NAME = "train_config.json"
|
TRAIN_CONFIG_NAME = "train_config.json"
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
|
from lerobot.common.policies.config import PreTrainedConfig
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.robot_devices.control_configs import (
|
from lerobot.common.robot_devices.control_configs import (
|
||||||
CalibrateControlConfig,
|
CalibrateControlConfig,
|
||||||
@@ -49,7 +50,6 @@ from lerobot.common.robot_devices.control_configs import (
|
|||||||
ReplayControlConfig,
|
ReplayControlConfig,
|
||||||
TeleoperateControlConfig,
|
TeleoperateControlConfig,
|
||||||
)
|
)
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||||
from tests.test_robots import make_robot
|
from tests.test_robots import make_robot
|
||||||
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||||
|
|||||||
Reference in New Issue
Block a user