Compare commits

...

9 Commits

Author SHA1 Message Date
Simon Alibert
f26e233995 Remove register 2025-03-10 13:30:14 +01:00
Simon Alibert
57e0bdb491 Move make_policy_config 2025-03-08 12:42:43 +01:00
Simon Alibert
00698305a3 Simplify optim imports 2025-03-08 12:27:20 +01:00
Simon Alibert
ec2990518b Move transformers to default dependencies 2025-03-08 11:48:17 +01:00
Simon Alibert
80a6cee699 Fix typing 2025-03-08 11:09:21 +01:00
Simon Alibert
80387285bb Cleanup policies imports 2025-03-08 11:08:40 +01:00
Simon Alibert
6909b62a21 Move pretrained config 2025-03-08 10:32:21 +01:00
Simon Alibert
c91a53be11 Add register mechanism 2025-03-08 10:24:39 +01:00
Simon Alibert
727638dda5 Add inits 2025-03-08 10:23:55 +01:00
31 changed files with 84 additions and 80 deletions

View File

@@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import (
MultiLeRobotDataset,
)
from lerobot.common.datasets.transforms import ImageTransforms
from lerobot.configs.policies import PreTrainedConfig
from lerobot.common.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
IMAGENET_STATS = {

View File

@@ -12,4 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .optimizers import OptimizerConfig as OptimizerConfig
from .optimizers import OptimizerConfig
from .schedulers import LRSchedulerConfig
__all__ = ["OptimizerConfig", "LRSchedulerConfig"]

View File

@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from . import act, diffusion, pi0, tdmpc, vqbet
from .config import PreTrainedConfig
from .factory import make_policy
from .pretrained import PreTrainedPolicy
__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy", "PreTrainedConfig", "PreTrainedPolicy"]

View File

@@ -0,0 +1,4 @@
from .configuration_act import ACTConfig
from .modeling_act import ACTPolicy
__all__ = ["ACTConfig", "ACTPolicy"]

View File

@@ -16,9 +16,10 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("act")
@dataclass

View File

@@ -23,8 +23,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

View File

@@ -0,0 +1,4 @@
from .configuration_diffusion import DiffusionConfig
from .modeling_diffusion import DiffusionPolicy
__all__ = ["DiffusionConfig", "DiffusionPolicy"]

View File

@@ -18,9 +18,10 @@ from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("diffusion")
@dataclass

View File

@@ -22,57 +22,38 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from .config import PreTrainedConfig
from .pretrained import PreTrainedPolicy
def get_policy_class(name: str) -> PreTrainedPolicy:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy
elif name == "diffusion":
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy
elif name == "act":
from lerobot.common.policies.act.modeling_act import ACTPolicy
if name == "act":
from .act import ACTPolicy
return ACTPolicy
elif name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
elif name == "diffusion":
from .diffusion import DiffusionPolicy
return VQBeTPolicy
return DiffusionPolicy
elif name == "pi0":
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
from .pi0 import PI0Policy
return PI0Policy
elif name == "tdmpc":
from .tdmpc import TDMPCPolicy
return TDMPCPolicy
elif name == "vqbet":
from .vqbet import VQBeTPolicy
return VQBeTPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,

View File

@@ -0,0 +1,4 @@
from .configuration_pi0 import PI0Config
from .modeling_pi0 import PI0Policy
__all__ = ["PI0Config", "PI0Policy"]

View File

@@ -18,9 +18,10 @@ from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("pi0")
@dataclass

View File

@@ -15,8 +15,7 @@
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
from lerobot.common.policies import PreTrainedConfig, make_policy
torch.backends.cudnn.benchmark = True

View File

@@ -19,8 +19,7 @@ from pathlib import Path
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
from lerobot.common.policies import PreTrainedConfig, make_policy
def display(tensor: torch.Tensor):

View File

@@ -22,11 +22,6 @@
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Install pi0 extra dependencies:
```bash
pip install -e ".[pi0]"
```
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
```bash
python lerobot/scripts/train.py \

View File

@@ -27,7 +27,8 @@ from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.common.utils.hub import HubMixin
from lerobot.configs.policies import PreTrainedConfig
from .config import PreTrainedConfig
T = TypeVar("T", bound="PreTrainedPolicy")

View File

@@ -0,0 +1,4 @@
from .configuration_tdmpc import TDMPCConfig
from .modeling_tdmpc import TDMPCPolicy
__all__ = ["TDMPCConfig", "TDMPCPolicy"]

View File

@@ -17,9 +17,10 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("tdmpc")
@dataclass

View File

@@ -0,0 +1,4 @@
from .configuration_vqbet import VQBeTConfig
from .modeling_vqbet import VQBeTPolicy
__all__ = ["VQBeTConfig", "VQBeTPolicy"]

View File

@@ -20,9 +20,10 @@ from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from ..config import PreTrainedConfig
@PreTrainedConfig.register_subclass("vqbet")
@dataclass

View File

@@ -17,9 +17,9 @@ from pathlib import Path
import draccus
from lerobot.common.policies import PreTrainedConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
@dataclass

View File

@@ -18,9 +18,9 @@ from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common import envs, policies # noqa: F401
from lerobot.common.policies import PreTrainedConfig
from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
@dataclass

View File

@@ -22,12 +22,11 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.common.policies import PreTrainedConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
TRAIN_CONFIG_NAME = "train_config.json"

View File

@@ -141,7 +141,7 @@ from pprint import pformat
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
ControlPipelineConfig,

View File

@@ -67,8 +67,7 @@ from tqdm import trange
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies import PreTrainedPolicy, make_policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.random_utils import set_seed

View File

@@ -29,8 +29,7 @@ from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies import PreTrainedPolicy, make_policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.common.utils.random_utils import set_seed

View File

@@ -71,6 +71,7 @@ dependencies = [
"termcolor>=2.4.0",
"torch>=2.2.1",
"torchvision>=0.21.0",
"transformers>=4.48.0",
"wandb>=0.16.3",
"zarr>=2.17.0",
]
@@ -84,7 +85,6 @@ dora = [
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
pi0 = ["transformers>=4.48.0"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",

View File

@@ -21,10 +21,11 @@ from safetensors.torch import save_file
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy, make_policy_config
from lerobot.common.policies import make_policy
from lerobot.common.utils.random_utils import set_seed
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.utils import make_policy_config
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):

View File

@@ -41,15 +41,14 @@ from unittest.mock import patch
import pytest
from lerobot.common.policies import PreTrainedConfig, make_policy
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
RecordControlConfig,
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot

View File

@@ -40,12 +40,11 @@ from lerobot.common.datasets.utils import (
unflatten_dict,
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import require_x86_64_kernel
from tests.utils import make_policy_config, require_x86_64_kernel
@pytest.fixture

View File

@@ -28,11 +28,11 @@ from lerobot.common.datasets.utils import cycle, dataset_to_policy_features
from lerobot.common.envs.factory import make_env, make_env_config
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.act import ACTConfig
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.common.policies.factory import (
get_policy_class,
make_policy,
make_policy_config,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -41,7 +41,7 @@ from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
from tests.utils import DEVICE, make_policy_config, require_cpu, require_env, require_x86_64_kernel
@pytest.fixture
@@ -208,11 +208,10 @@ def test_act_backbone_lr():
"""
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
"""
cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
policy=ACTConfig(optimizer_lr=0.01, optimizer_lr_backbone=0.001),
)
cfg.validate() # Needed for auto-setting some parameters

View File

@@ -23,6 +23,7 @@ import pytest
import torch
from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.policies import PreTrainedConfig
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.motors.utils import MotorsBus
@@ -329,3 +330,8 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
else:
raise ValueError(f"The motor type '{motor_type}' is not valid.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
policy_cfg_cls = PreTrainedConfig.get_choice_class(policy_type)
return policy_cfg_cls(**kwargs)