forked from tangger/lerobot
Compare commits
9 Commits
main
...
user/alibe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f26e233995 | ||
|
|
57e0bdb491 | ||
|
|
00698305a3 | ||
|
|
ec2990518b | ||
|
|
80a6cee699 | ||
|
|
80387285bb | ||
|
|
6909b62a21 | ||
|
|
c91a53be11 | ||
|
|
727638dda5 |
@@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import (
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransforms
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
IMAGENET_STATS = {
|
||||
|
||||
@@ -12,4 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .optimizers import OptimizerConfig as OptimizerConfig
|
||||
from .optimizers import OptimizerConfig
|
||||
from .schedulers import LRSchedulerConfig
|
||||
|
||||
__all__ = ["OptimizerConfig", "LRSchedulerConfig"]
|
||||
|
||||
@@ -12,8 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from . import act, diffusion, pi0, tdmpc, vqbet
|
||||
from .config import PreTrainedConfig
|
||||
from .factory import make_policy
|
||||
from .pretrained import PreTrainedPolicy
|
||||
|
||||
__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy", "PreTrainedConfig", "PreTrainedPolicy"]
|
||||
|
||||
4
lerobot/common/policies/act/__init__.py
Normal file
4
lerobot/common/policies/act/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .configuration_act import ACTConfig
|
||||
from .modeling_act import ACTPolicy
|
||||
|
||||
__all__ = ["ACTConfig", "ACTPolicy"]
|
||||
@@ -16,9 +16,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("act")
|
||||
@dataclass
|
||||
|
||||
@@ -23,8 +23,7 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.common.optim.optimizers import OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.common.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
4
lerobot/common/policies/diffusion/__init__.py
Normal file
4
lerobot/common/policies/diffusion/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .configuration_diffusion import DiffusionConfig
|
||||
from .modeling_diffusion import DiffusionPolicy
|
||||
|
||||
__all__ = ["DiffusionConfig", "DiffusionPolicy"]
|
||||
@@ -18,9 +18,10 @@ from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("diffusion")
|
||||
@dataclass
|
||||
|
||||
@@ -22,57 +22,38 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
from .config import PreTrainedConfig
|
||||
from .pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
return DiffusionPolicy
|
||||
elif name == "act":
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
if name == "act":
|
||||
from .act import ACTPolicy
|
||||
|
||||
return ACTPolicy
|
||||
elif name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
elif name == "diffusion":
|
||||
from .diffusion import DiffusionPolicy
|
||||
|
||||
return VQBeTPolicy
|
||||
return DiffusionPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from .pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "tdmpc":
|
||||
from .tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy
|
||||
elif name == "vqbet":
|
||||
from .vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
if policy_type == "tdmpc":
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "act":
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
|
||||
4
lerobot/common/policies/pi0/__init__.py
Normal file
4
lerobot/common/policies/pi0/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .configuration_pi0 import PI0Config
|
||||
from .modeling_pi0 import PI0Policy
|
||||
|
||||
__all__ = ["PI0Config", "PI0Policy"]
|
||||
@@ -18,9 +18,10 @@ from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.common.policies import PreTrainedConfig, make_policy
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
@@ -19,8 +19,7 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.common.policies import PreTrainedConfig, make_policy
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
|
||||
@@ -22,11 +22,6 @@
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Install pi0 extra dependencies:
|
||||
```bash
|
||||
pip install -e ".[pi0]"
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
|
||||
@@ -27,7 +27,8 @@ from safetensors.torch import save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
from .config import PreTrainedConfig
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
|
||||
4
lerobot/common/policies/tdmpc/__init__.py
Normal file
4
lerobot/common/policies/tdmpc/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .configuration_tdmpc import TDMPCConfig
|
||||
from .modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
__all__ = ["TDMPCConfig", "TDMPCPolicy"]
|
||||
@@ -17,9 +17,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("tdmpc")
|
||||
@dataclass
|
||||
|
||||
4
lerobot/common/policies/vqbet/__init__.py
Normal file
4
lerobot/common/policies/vqbet/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .configuration_vqbet import VQBeTConfig
|
||||
from .modeling_vqbet import VQBeTPolicy
|
||||
|
||||
__all__ = ["VQBeTConfig", "VQBeTPolicy"]
|
||||
@@ -20,9 +20,10 @@ from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from ..config import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vqbet")
|
||||
@dataclass
|
||||
|
||||
@@ -17,9 +17,9 @@ from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -18,9 +18,9 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common import envs, policies # noqa: F401
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import EvalConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -22,12 +22,11 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.common import envs
|
||||
from lerobot.common.optim import OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.common.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ from pprint import pformat
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
ControlPipelineConfig,
|
||||
|
||||
@@ -67,8 +67,7 @@ from tqdm import trange
|
||||
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies import PreTrainedPolicy, make_policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.io_utils import write_video
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
|
||||
@@ -29,8 +29,7 @@ from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies import PreTrainedPolicy, make_policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
|
||||
@@ -71,6 +71,7 @@ dependencies = [
|
||||
"termcolor>=2.4.0",
|
||||
"torch>=2.2.1",
|
||||
"torchvision>=0.21.0",
|
||||
"transformers>=4.48.0",
|
||||
"wandb>=0.16.3",
|
||||
"zarr>=2.17.0",
|
||||
]
|
||||
@@ -84,7 +85,6 @@ dora = [
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
stretch = [
|
||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||
|
||||
@@ -21,10 +21,11 @@ from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.common.policies import make_policy
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from tests.utils import make_policy_config
|
||||
|
||||
|
||||
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
|
||||
@@ -41,15 +41,14 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies import PreTrainedConfig, make_policy
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
RecordControlConfig,
|
||||
ReplayControlConfig,
|
||||
TeleoperateControlConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from tests.test_robots import make_robot
|
||||
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
|
||||
@@ -40,12 +40,11 @@ from lerobot.common.datasets.utils import (
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.envs.factory import make_env_config
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.utils import require_x86_64_kernel
|
||||
from tests.utils import make_policy_config, require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -28,11 +28,11 @@ from lerobot.common.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.common.envs.factory import make_env, make_env_config
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.act import ACTConfig
|
||||
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.common.policies.factory import (
|
||||
get_policy_class,
|
||||
make_policy,
|
||||
make_policy_config,
|
||||
)
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
@@ -41,7 +41,7 @@ from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
from tests.utils import DEVICE, make_policy_config, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -208,11 +208,10 @@ def test_act_backbone_lr():
|
||||
"""
|
||||
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
|
||||
"""
|
||||
|
||||
cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||
policy=ACTConfig(optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||
)
|
||||
cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
@@ -329,3 +330,8 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
policy_cfg_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
return policy_cfg_cls(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user