Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -0,0 +1,6 @@
# keys
OBS_ENV = "observation.environment_state"
OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"

View File

@@ -14,103 +14,102 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pprint import pformat
import torch
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.common.datasets.transforms import ImageTransforms
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}
def resolve_delta_timestamps(cfg):
"""Resolves delta_timestamps config key (in-place) by using `eval`.
def resolve_delta_timestamps(
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
) -> dict[str, list] | None:
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
Doesn't do anything if delta_timestamps is not specified or has already been resolve (as evidenced by
the data type of its values).
"""
delta_timestamps = cfg.training.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
# TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
"""
Args:
cfg: A Hydra config as per the LeRobot config scheme.
split: Select the data subset used to create an instance of LeRobotDataset.
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
slicer in the hugging face datasets:
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
delta_timestamps against.
Returns:
The LeRobotDataset.
dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
{
"observation.state": [-0.04, -0.02, 0]
"observation.action": [-0.02, 0, 0.02]
}
returns `None` if the the resulting dict is empty.
"""
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
raise ValueError(
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
"strings to load multiple datasets."
)
delta_timestamps = {}
for key in ds_meta.features:
if key == "next.reward" and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == "action" and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
if cfg.env.name != "dora":
if isinstance(cfg.dataset_repo_id, str):
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
else:
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
if len(delta_timestamps) == 0:
delta_timestamps = None
for dataset_repo_id in dataset_repo_ids:
if cfg.env.name not in dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
)
return delta_timestamps
resolve_delta_timestamps(cfg)
image_transforms = None
if cfg.training.image_transforms.enable:
cfg_tf = cfg.training.image_transforms
image_transforms = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
)
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
if isinstance(cfg.dataset_repo_id, str):
# TODO (aliberts): add 'episodes' arg from config after removing hydra
Args:
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
Raises:
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
Returns:
LeRobotDataset | MultiLeRobotDataset
"""
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)
if isinstance(cfg.dataset.repo_id, str):
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset_repo_id,
delta_timestamps=cfg.training.get("delta_timestamps"),
cfg.dataset.repo_id,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.video_backend,
video_backend=cfg.dataset.video_backend,
local_files_only=cfg.dataset.local_files_only,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id,
delta_timestamps=cfg.training.get("delta_timestamps"),
cfg.dataset.repo_id,
# TODO(aliberts): add proper support for multi dataset
# delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.video_backend,
video_backend=cfg.dataset.video_backend,
)
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index , indent=2)}"
)
if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():
for stats_type, listconfig in stats_dict.items():
# example of stats_type: min, max, mean, std
stats = OmegaConf.to_container(listconfig, resolve=True)
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset

View File

@@ -840,7 +840,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def stop_image_writer(self) -> None:
"""
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
"""
if self.image_writer is not None:
self.image_writer.stop()

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from dataclasses import dataclass, field
from typing import Any, Callable, Sequence
import torch
@@ -65,6 +66,8 @@ class RandomSubsetApply(Transform):
self.n_subset = n_subset
self.random_order = random_order
self.selected_transforms = None
def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1
@@ -72,9 +75,9 @@ class RandomSubsetApply(Transform):
if not self.random_order:
selected_indices = selected_indices.sort().values
selected_transforms = [self.transforms[i] for i in selected_indices]
self.selected_transforms = [self.transforms[i] for i in selected_indices]
for transform in selected_transforms:
for transform in self.selected_transforms:
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
@@ -138,61 +141,109 @@ class SharpnessJitter(Transform):
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def get_image_transforms(
brightness_weight: float = 1.0,
brightness_min_max: tuple[float, float] | None = None,
contrast_weight: float = 1.0,
contrast_min_max: tuple[float, float] | None = None,
saturation_weight: float = 1.0,
saturation_min_max: tuple[float, float] | None = None,
hue_weight: float = 1.0,
hue_min_max: tuple[float, float] | None = None,
sharpness_weight: float = 1.0,
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
):
def check_value(name, weight, min_max):
if min_max is not None:
if len(min_max) != 2:
raise ValueError(
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
)
if weight < 0.0:
raise ValueError(
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
)
@dataclass
class ImageTransformConfig:
"""
For each transform, the following parameters are available:
weight: This represents the multinomial probability (with no replacement)
used for sampling the transform. If the sum of the weights is not 1,
they will be normalized.
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
custom transform defined here.
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
(following uniform distribution) when it's applied.
"""
check_value("brightness", brightness_weight, brightness_min_max)
check_value("contrast", contrast_weight, contrast_min_max)
check_value("saturation", saturation_weight, saturation_min_max)
check_value("hue", hue_weight, hue_min_max)
check_value("sharpness", sharpness_weight, sharpness_min_max)
weight: float = 1.0
type: str = "Identity"
kwargs: dict[str, Any] = field(default_factory=dict)
weights = []
transforms = []
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
if contrast_min_max is not None and contrast_weight > 0.0:
weights.append(contrast_weight)
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
if saturation_min_max is not None and saturation_weight > 0.0:
weights.append(saturation_weight)
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
if hue_min_max is not None and hue_weight > 0.0:
weights.append(hue_weight)
transforms.append(v2.ColorJitter(hue=hue_min_max))
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
n_subset = len(transforms)
if max_num_transforms is not None:
n_subset = min(n_subset, max_num_transforms)
@dataclass
class ImageTransformsConfig:
"""
These transforms are all using standard torchvision.transforms.v2
You can find out how these transformations affect images here:
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
We use a custom RandomSubsetApply container to sample them.
"""
if n_subset == 0:
return v2.Identity()
# Set this flag to `true` to enable transforms during training
enable: bool = False
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [1, number_of_available_transforms].
max_num_transforms: int = 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: bool = False
tfs: dict[str, ImageTransformConfig] = field(
default_factory=lambda: {
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
}
)
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
else:
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
raise ValueError(f"Transform '{cfg.type}' is not valid.")
class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration."""
def __init__(self, cfg: ImageTransformsConfig) -> None:
super().__init__()
self._cfg = cfg
self.weights = []
self.transforms = {}
for tf_name, tf_cfg in cfg.tfs.items():
if tf_cfg.weight <= 0.0:
continue
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
self.weights.append(tf_cfg.weight)
n_subset = min(len(self.transforms), cfg.max_num_transforms)
if n_subset == 0 or not cfg.enable:
self.tf = v2.Identity()
else:
self.tf = RandomSubsetApply(
transforms=list(self.transforms.values()),
p=self.weights,
n_subset=n_subset,
random_order=cfg.random_order,
)
def forward(self, *inputs: Any) -> Any:
return self.tf(*inputs)

View File

@@ -35,6 +35,7 @@ from PIL import Image as PILImage
from torchvision import transforms
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
@@ -98,6 +99,18 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
split_keys = flattened_key.split(sep)
getter = obj[split_keys[0]]
if len(split_keys) == 1:
return getter
for key in split_keys[1:]:
getter = getter[key]
return getter
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
return unflatten_dict(serialized_dict)
@@ -289,6 +302,37 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
# TODO(aliberts): Implement "type" in dataset features and simplify this
policy_features = {}
for key, ft in features.items():
shape = ft["shape"]
if ft["dtype"] in ["image", "video"]:
type = FeatureType.VISUAL
if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state":
type = FeatureType.ENV
elif key.startswith("observation"):
type = FeatureType.STATE
elif key == "action":
type = FeatureType.ACTION
else:
continue
policy_features[key] = PolicyFeature(
type=type,
shape=shape,
)
return policy_features
def create_empty_dataset_info(
codebase_version: str,
fps: int,
@@ -436,7 +480,7 @@ def check_delta_timestamps(
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
delta_indices[key] = [round(d * fps) for d in delta_ts]
return delta_indices

View File

@@ -26,13 +26,13 @@ from pathlib import Path
from textwrap import dedent
from lerobot import available_datasets
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
LOCAL_DIR = Path("data/")
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
ALOHA_MOBILE_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"robot_config": AlohaRobotConfig(),
"license": "mit",
"url": "https://mobile-aloha.github.io/",
"paper": "https://arxiv.org/abs/2401.02117",
@@ -45,7 +45,7 @@ ALOHA_MOBILE_INFO = {
}""").lstrip(),
}
ALOHA_STATIC_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"robot_config": AlohaRobotConfig(),
"license": "mit",
"url": "https://tonyzhaozh.github.io/aloha/",
"paper": "https://arxiv.org/abs/2304.13705",

View File

@@ -141,7 +141,8 @@ from lerobot.common.datasets.video_utils import (
get_image_pixel_channels,
get_video_info,
)
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_config
V16 = "v1.6"
V20 = "v2.0"
@@ -152,19 +153,18 @@ V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
robot_cfg = init_hydra_config(config_path, config_overrides)
if robot_cfg["robot_type"] in ["aloha", "koch"]:
def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
if robot_cfg.type in ["aloha", "koch"]:
state_names = [
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
for arm in robot_cfg["follower_arms"]
for motor in robot_cfg["follower_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
for arm in robot_cfg.follower_arms
for motor in robot_cfg.follower_arms[arm].motors
]
action_names = [
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
for arm in robot_cfg["leader_arms"]
for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
for arm in robot_cfg.leader_arms
for motor in robot_cfg.leader_arms[arm].motors
]
# elif robot_cfg["robot_type"] == "stretch3": TODO
else:
@@ -173,7 +173,7 @@ def parse_robot_config(config_path: Path, config_overrides: list[str] | None = N
)
return {
"robot_type": robot_cfg["robot_type"],
"robot_type": robot_cfg.type,
"names": {
"observation.state": state_names,
"observation.effort": state_names,
@@ -203,7 +203,10 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
torch.testing.assert_close(stats_json[key], stats[key])
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
def get_features_from_hf_dataset(
dataset: Dataset, robot_config: RobotConfig | None = None
) -> dict[str, list]:
robot_config = parse_robot_config(robot_config)
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
@@ -224,11 +227,11 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
image = dataset[0][key] # Assuming first row
channels = get_image_pixel_channels(image)
shape = (image.height, image.width, channels)
names = ["height", "width", "channel"]
names = ["height", "width", "channels"]
elif ft._type == "VideoFrame":
dtype = "video"
shape = None # Add shape later
names = ["height", "width", "channel"]
names = ["height", "width", "channels"]
features[key] = {
"dtype": dtype,
@@ -436,7 +439,7 @@ def convert_dataset(
single_task: str | None = None,
tasks_path: Path | None = None,
tasks_col: Path | None = None,
robot_config: dict | None = None,
robot_config: RobotConfig | None = None,
test_branch: str | None = None,
**card_kwargs,
):
@@ -532,7 +535,7 @@ def convert_dataset(
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
if robot_config is not None:
robot_type = robot_config["robot_type"]
robot_type = robot_config.type
repo_tags = [robot_type]
else:
robot_type = "unknown"
@@ -621,16 +624,10 @@ def main():
help="The path to a .json file containing one language instruction for each episode_index",
)
parser.add_argument(
"--robot-config",
type=Path,
default=None,
help="Path to the robot's config yaml the dataset during conversion.",
)
parser.add_argument(
"--robot-overrides",
"--robot",
type=str,
nargs="*",
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)",
default=None,
help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
)
parser.add_argument(
"--local-dir",
@@ -655,8 +652,10 @@ def main():
if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
del args.robot_config, args.robot_overrides
if args.robot is not None:
robot_config = make_robot_config(args.robot)
del args.robot
convert_dataset(**vars(args), robot_config=robot_config)

View File

@@ -0,0 +1 @@
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401

View File

@@ -0,0 +1,142 @@
import abc
from dataclasses import dataclass, field
import draccus
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
from lerobot.configs.types import FeatureType, PolicyFeature
@dataclass
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
task: str | None = None
fps: int = 30
features: dict[str, PolicyFeature] = field(default_factory=dict)
features_map: dict[str, str] = field(default_factory=dict)
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@abc.abstractproperty
def gym_kwargs(self) -> dict:
raise NotImplementedError()
@EnvConfig.register_subclass("aloha")
@dataclass
class AlohaEnv(EnvConfig):
task: str = "AlohaInsertion-v0"
fps: int = 50
episode_length: int = 400
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"max_episode_steps": self.episode_length,
}
@EnvConfig.register_subclass("pusht")
@dataclass
class PushtEnv(EnvConfig):
task: str = "PushT-v0"
fps: int = 10
episode_length: int = 300
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"environment_state": OBS_ENV,
"pixels": OBS_IMAGE,
}
)
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
elif self.obs_type == "environment_state_agent_pos":
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
@EnvConfig.register_subclass("xarm")
@dataclass
class XarmEnv(EnvConfig):
task: str = "XarmLift-v0"
fps: int = 15
episode_length: int = 200
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"pixels": OBS_IMAGE,
}
)
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}

View File

@@ -16,43 +16,54 @@
import importlib
import gymnasium as gym
from omegaconf import DictConfig
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the evaluation config.
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
if env_type == "aloha":
return AlohaEnv(**kwargs)
elif env_type == "pusht":
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the config.
Args:
cfg (EnvConfig): the config of the environment to instantiate.
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
Raises:
ValueError: if n_envs < 1
ModuleNotFoundError: If the requested env package is not intalled
Returns:
gym.vector.VectorEnv: The parallelized gym.env instance.
"""
if n_envs is not None and n_envs < 1:
if n_envs < 1:
raise ValueError("`n_envs must be at least 1")
if cfg.env.name == "real_world":
return
package_name = f"gym_{cfg.env.name}"
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
)
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
gym_handle = f"{package_name}/{cfg.env.task}"
gym_kwgs = dict(cfg.env.get("gym", {}))
if cfg.env.get("episode_length"):
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
gym_handle = f"{package_name}/{cfg.task}"
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
]
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
return env

View File

@@ -18,8 +18,13 @@ import numpy as np
import torch
from torch import Tensor
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.utils.utils import get_channel_first_image_shape
from lerobot.configs.types import FeatureType, PolicyFeature
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
@@ -35,6 +40,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
imgs = {"observation.image": observations["pixels"]}
for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)
# sanity check that images are channel last
@@ -60,3 +66,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return return_observations
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies)
policy_features = {}
for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL:
if len(ft.shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
shape = get_channel_first_image_shape(ft.shape)
feature = PolicyFeature(type=ft.type, shape=shape)
else:
feature = ft
policy_key = env_cfg.features_map[key]
policy_features[policy_key] = feature
return policy_features

View File

@@ -21,32 +21,39 @@
import logging
import os
import re
from dataclasses import asdict
from glob import glob
from pathlib import Path
import draccus
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_global_random_state
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode
PRETRAINED_MODEL = "pretrained_model"
TRAINING_STATE = "training_state.pth"
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.name}",
f"dataset:{cfg.dataset_repo_id}",
f"env:{cfg.env.name}",
f"policy:{cfg.policy.type}",
f"dataset:{cfg.dataset.repo_id}",
f"seed:{cfg.seed}",
]
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
return lst if return_list else "-".join(lst)
@@ -68,7 +75,6 @@ class Logger:
The logger creates the following directory structure:
provided_log_dir
├── .hydra # hydra's configuration cache
├── checkpoints
│ ├── specific_checkpoint_name
│ │ ├── pretrained_model # Hugging Face pretrained model directory
@@ -80,28 +86,21 @@ class Logger:
│ └── last # a softlink to the last logged checkpoint
"""
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"
pretrained_model_dir_name = PRETRAINED_MODEL
training_state_file_name = TRAINING_STATE
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
job_name: The WandB job name.
"""
def __init__(self, cfg: TrainPipelineConfig):
self._cfg = cfg
self.log_dir = Path(log_dir)
self.log_dir = cfg.output_dir
self.log_dir.mkdir(parents=True, exist_ok=True)
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
self.job_name = cfg.job_name
self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir)
# Set up WandB.
self._group = cfg_to_group(cfg)
project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity")
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project
run_offline = not cfg.wandb.enable or not cfg.wandb.project
if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
@@ -115,13 +114,13 @@ class Logger:
wandb.init(
id=wandb_run_id,
project=project,
entity=entity,
name=wandb_job_name,
notes=cfg.get("wandb", {}).get("notes"),
project=cfg.wandb.project,
entity=cfg.wandb.entity,
name=self.job_name,
notes=cfg.wandb.notes,
tags=cfg_to_group(cfg, return_list=True),
dir=log_dir,
config=OmegaConf.to_container(cfg, resolve=True),
dir=self.log_dir,
config=asdict(self._cfg),
# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
@@ -150,17 +149,19 @@ class Logger:
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
def save_model(self, save_dir: Path, policy: PreTrainedPolicy, wandb_artifact_name: str | None = None):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
Optionally also upload the model to WandB.
"""
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
register_features_types()
policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
# Also save the full config for the env configuration.
self._cfg.save_pretrained(save_dir)
if self._wandb and not self._cfg.wandb.disable_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
@@ -173,18 +174,18 @@ class Logger:
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer,
scheduler: LRScheduler | None,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
training_state = {
"step": train_step,
"optimizer": optimizer.state_dict(),
**get_global_random_state(),
}
training_state = {}
training_state["step"] = train_step
training_state.update(get_global_random_state())
if optimizer is not None:
training_state["optimizer"] = optimizer.state_dict()
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
@@ -192,10 +193,10 @@ class Logger:
def save_checkpoint(
self,
train_step: int,
policy: Policy,
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
policy: PreTrainedPolicy,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
@@ -208,26 +209,11 @@ class Logger:
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
raise ValueError(
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
)
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"]
relative_target = checkpoint_dir.relative_to(self.last_checkpoint_dir.parent)
self.last_checkpoint_dir.symlink_to(relative_target)
def log_dict(self, d, step, mode="train"):
def log_dict(self, d: dict, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if self._wandb is not None:
@@ -242,5 +228,13 @@ class Logger:
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
assert self._wandb is not None
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
wandb_video = self._wandb.Video(video_path, fps=self._cfg.env.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
def register_features_types():
draccus.decode.register(FeatureType, lambda x: FeatureType[x])
draccus.encode.register(FeatureType, lambda x: x.name)
draccus.decode.register(NormalizationMode, lambda x: NormalizationMode[x])
draccus.encode.register(NormalizationMode, lambda x: x.name)

View File

@@ -0,0 +1 @@
from .optimizers import OptimizerConfig as OptimizerConfig

View File

@@ -0,0 +1,61 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.logger import TRAINING_STATE
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
from lerobot.configs.train import TrainPipelineConfig
def make_optimizer_and_scheduler(
cfg: TrainPipelineConfig, policy: PreTrainedPolicy
) -> tuple[Optimizer, LRScheduler | None]:
"""Generates the optimizer and scheduler based on configs.
Args:
cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
policy (PreTrainedPolicy): The policy config from which parameters and presets must be taken from.
Returns:
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
"""
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
optimizer = cfg.optimizer.build(params)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.offline.steps) if cfg.scheduler is not None else None
return optimizer, lr_scheduler
def load_training_state(checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the checkpoint directory, load the optimizer state, scheduler state, and random state, and
return the global training step.
"""
# TODO(aliberts): use safetensors instead as weights_only=False is unsafe
training_state = torch.load(checkpoint_dir / TRAINING_STATE, weights_only=False)
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
raise ValueError("The checkpoint contains a scheduler state_dict, but no LRScheduler was provided.")
# Small HACK to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"], optimizer, scheduler

View File

@@ -0,0 +1,56 @@
import abc
from dataclasses import asdict, dataclass
import draccus
import torch
@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
lr: float
betas: tuple[float, float]
eps: float
weight_decay: float
grad_clip_norm: float
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@classmethod
def default_choice_name(cls) -> str | None:
return "adam"
@abc.abstractmethod
def build(self) -> torch.optim.Optimizer:
raise NotImplementedError
@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
lr: float = 1e-3
betas: tuple[float, float] = (0.9, 0.999)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.Adam(params, **kwargs)
@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
lr: float = 1e-3
betas: tuple[float, float] = (0.9, 0.999)
eps: float = 1e-8
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.AdamW(params, **kwargs)

View File

@@ -0,0 +1,56 @@
import abc
import math
from dataclasses import asdict, dataclass
import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
num_warmup_steps: int
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@abc.abstractmethod
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
raise NotImplementedError
@LRSchedulerConfig.register_subclass("diffuser")
@dataclass
class DiffuserSchedulerConfig(LRSchedulerConfig):
name: str = "cosine"
num_warmup_steps: int | None = None
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
from diffusers.optimization import get_scheduler
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
return get_scheduler(**kwargs)
@LRSchedulerConfig.register_subclass("vqbet")
@dataclass
class VQBeTSchedulerConfig(LRSchedulerConfig):
num_warmup_steps: int
num_vqvae_training_steps: int
num_cycles: float = 0.5
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step):
if current_step < self.num_vqvae_training_steps:
return float(1)
else:
adjusted_step = current_step - self.num_vqvae_training_steps
if adjusted_step < self.num_warmup_steps:
return float(adjusted_step) / float(max(1, self.num_warmup_steps))
progress = float(adjusted_step - self.num_warmup_steps) / float(
max(1, num_training_steps - self.num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
return LambdaLR(optimizer, lr_lambda, -1)

View File

@@ -0,0 +1,4 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

View File

@@ -15,9 +15,14 @@
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("act")
@dataclass
class ACTConfig:
class ACTConfig(PreTrainedConfig):
"""Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -90,28 +95,11 @@ class ACTConfig:
chunk_size: int = 100
n_action_steps: int = 100
input_shapes: dict[str, list[int]] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [14],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.images.top": "mean_std",
"observation.state": "mean_std",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
@@ -144,7 +132,14 @@ class ACTConfig:
dropout: float = 0.1
kl_weight: float = 10.0
# Training preset
optimizer_lr: float = 1e-5
optimizer_weight_decay: float = 1e-4
optimizer_lr_backbone: float = 1e-5
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
@@ -164,8 +159,28 @@ class ACTConfig:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if (
not any(k.startswith("observation.image") for k in self.input_shapes)
and "observation.environment_state" not in self.input_shapes
):
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -29,32 +29,27 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
class ACTPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "act"],
):
class ACTPolicy(PreTrainedPolicy):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
"""
config_class = ACTConfig
name = "act"
def __init__(
self,
config: ACTConfig | None = None,
config: ACTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
@@ -64,30 +59,46 @@ class ACTPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = ACTConfig()
self.config: ACTConfig = config
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = ACT(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset()
def get_optim_params(self) -> dict:
# TODO(aliberts, rcadene): As of now, lr_backbone == lr
# Should we remove this and just `return self.parameters()`?
return [
{
"params": [
p
for n, p in self.named_parameters()
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in self.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": self.config.optimizer_lr_backbone,
},
]
def reset(self):
"""This should be called whenever the environment is reset."""
if self.config.temporal_ensemble_coeff is not None:
@@ -106,9 +117,11 @@ class ACTPolicy(
self.eval()
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -134,9 +147,11 @@ class ACTPolicy(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -288,31 +303,30 @@ class ACT(nn.Module):
"""
def __init__(self, config: ACTConfig):
super().__init__()
self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_robot_state = "observation.state" in config.input_shapes
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
self.use_env_state = "observation.environment_state" in config.input_shapes
super().__init__()
self.config = config
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.use_robot_state:
if self.config.robot_state_feature:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
self.config.robot_state_feature.shape[0], config.dim_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
config.output_shapes["action"][0], config.dim_model
self.config.action_feature.shape[0],
config.dim_model,
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.use_robot_state:
if self.config.robot_state_feature:
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
@@ -320,7 +334,7 @@ class ACT(nn.Module):
)
# Backbone for image feature extraction.
if self.use_images:
if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=config.pretrained_backbone_weights,
@@ -337,27 +351,27 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.use_robot_state:
if self.config.robot_state_feature:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
self.config.robot_state_feature.shape[0], config.dim_model
)
if self.use_env_state:
if self.config.env_state_feature:
self.encoder_env_state_input_proj = nn.Linear(
config.input_shapes["observation.environment_state"][0], config.dim_model
self.config.env_state_feature.shape[0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
if self.use_images:
if self.config.image_features:
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
n_1d_tokens = 1 # for the latent
if self.use_robot_state:
if self.config.robot_state_feature:
n_1d_tokens += 1
if self.use_env_state:
if self.config.env_state_feature:
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.use_images:
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
@@ -365,7 +379,7 @@ class ACT(nn.Module):
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
self._reset_parameters()
@@ -380,13 +394,13 @@ class ACT(nn.Module):
`batch` should have the following structure:
{
"observation.state" (optional): (B, state_dim) batch of robot states.
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
"observation.images": (B, n_cameras, C, H, W) batch of images.
[image_features]: (B, n_cameras, C, H, W) batch of images.
AND/OR
"observation.environment_state": (B, env_dim) batch of environment states.
[env_state_feature]: (B, env_dim) batch of environment states.
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
}
Returns:
@@ -411,12 +425,12 @@ class ACT(nn.Module):
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_robot_state:
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.use_robot_state:
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
@@ -430,7 +444,7 @@ class ACT(nn.Module):
# sequence depending whether we use the input states or not (cls and robot state)
# False means not a padding token.
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.use_robot_state else 1),
(batch_size, 2 if self.config.robot_state_feature else 1),
False,
device=batch["observation.state"].device,
)
@@ -463,16 +477,16 @@ class ACT(nn.Module):
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.use_robot_state:
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.use_env_state:
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
if self.use_images:
if self.config.image_features:
all_cam_features = []
all_cam_pos_embeds = []

View File

@@ -16,9 +16,15 @@
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig:
class DiffusionConfig(PreTrainedConfig):
"""Configuration class for DiffusionPolicy.
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@@ -102,26 +108,17 @@ class DiffusionConfig:
horizon: int = 16
n_action_steps: int = 8
input_shapes: dict[str, list[int]] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.image": [3, 96, 96],
"observation.state": [2],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
# Architecture / modeling.
# Vision backbone.
@@ -154,39 +151,23 @@ class DiffusionConfig:
# Loss computation
do_mask_loss_for_padding: bool = False
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 500
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if len(image_keys) > 0:
if self.crop_shape is not None:
for image_key in image_keys:
if (
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError(
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for "
"`input_shapes[{image_key}]`."
)
# Check that all input images have the same shape.
first_image_key = next(iter(image_keys))
for image_key in image_keys:
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
raise ValueError(
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
"expect all image shapes to match."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(
@@ -207,3 +188,50 @@ class DiffusionConfig:
"The horizon should be an integer multiple of the downsampling factor (which is determined "
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
)
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
# Check that all input images have the same shape.
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
)
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -31,35 +31,32 @@ import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
get_output_shape,
populate_queues,
)
class DiffusionPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "diffusion-policy"],
):
class DiffusionPolicy(PreTrainedPolicy):
"""
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
"""
config_class = DiffusionConfig
name = "diffusion"
def __init__(
self,
config: DiffusionConfig | None = None,
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
@@ -69,18 +66,16 @@ class DiffusionPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = DiffusionConfig()
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
@@ -88,20 +83,20 @@ class DiffusionPolicy(
self.diffusion = DiffusionModel(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.use_env_state = "observation.environment_state" in config.input_shapes
self.reset()
def get_optim_params(self) -> dict:
return self.diffusion.parameters()
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
if len(self.expected_image_keys) > 0:
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.use_env_state:
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
@@ -127,9 +122,11 @@ class DiffusionPolicy(
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -149,9 +146,11 @@ class DiffusionPolicy(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
@@ -176,12 +175,9 @@ class DiffusionModel(nn.Module):
self.config = config
# Build observation encoders (depending on which observations are provided).
global_cond_dim = config.input_shapes["observation.state"][0]
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self._use_images = False
self._use_env_state = False
if num_images > 0:
self._use_images = True
global_cond_dim = self.config.robot_state_feature.shape[0]
if self.config.image_features:
num_images = len(self.config.image_features)
if self.config.use_separate_rgb_encoder_per_camera:
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
self.rgb_encoder = nn.ModuleList(encoders)
@@ -189,9 +185,8 @@ class DiffusionModel(nn.Module):
else:
self.rgb_encoder = DiffusionRgbEncoder(config)
global_cond_dim += self.rgb_encoder.feature_dim * num_images
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
global_cond_dim += config.input_shapes["observation.environment_state"][0]
if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0]
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
@@ -220,7 +215,7 @@ class DiffusionModel(nn.Module):
# Sample prior.
sample = torch.randn(
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
dtype=dtype,
device=device,
generator=generator,
@@ -242,10 +237,10 @@ class DiffusionModel(nn.Module):
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
global_cond_feats = [batch["observation.state"]]
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
global_cond_feats = [batch[OBS_ROBOT]]
# Extract image features.
if self._use_images:
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
@@ -272,8 +267,8 @@ class DiffusionModel(nn.Module):
)
global_cond_feats.append(img_features)
if self._use_env_state:
global_cond_feats.append(batch["observation.environment_state"])
if self.config.env_state_feature:
global_cond_feats.append(batch[OBS_ENV])
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
@@ -443,7 +438,7 @@ class SpatialSoftmax(nn.Module):
class DiffusionRgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector.
"""Encodes an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
"""
@@ -482,19 +477,16 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# height and width from `config.image_features`.
# Note: we have a check in the config class to make sure all images have the same shape.
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
@@ -611,7 +603,7 @@ class DiffusionConditionalUnet1d(nn.Module):
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
)
@@ -666,7 +658,7 @@ class DiffusionConditionalUnet1d(nn.Module):
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:

View File

@@ -13,99 +13,132 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_safe_torch_device
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
if not set(hydra_cfg.policy).issuperset(expected_kwargs):
logging.warning(
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
)
# OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
# issues with mutable defaults. This filter changes all lists to tuples.
def list_to_tuple(item):
return tuple(item) if isinstance(item, list) else item
policy_cfg = policy_cfg_class(
**{
k: list_to_tuple(v)
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
if k in expected_kwargs
}
)
return policy_cfg
def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
def get_policy_class(name: str) -> PreTrainedPolicy:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy, TDMPCConfig
return TDMPCPolicy
elif name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy, DiffusionConfig
return DiffusionPolicy
elif name == "act":
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.act.modeling_act import ACTPolicy
return ACTPolicy, ACTConfig
return ACTPolicy
elif name == "vqbet":
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy, VQBeTConfig
return VQBeTPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
def make_policy(
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
) -> Policy:
cfg: PreTrainedConfig,
device: str | torch.device,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
"""Make an instance of a policy class.
This function exists because (for now) we need to parse features from either a dataset or an environment
in order to properly dimension and instantiate a policy for that dataset or environment.
Args:
hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
be provided when initializing a new policy, and must not be provided when loading a pretrained
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path.
device (str): the device to load the policy onto.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
provided if ds_meta is not. Defaults to None.
Raises:
ValueError: Either ds_meta or env and env_cfg must be provided.
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
Returns:
PreTrainedPolicy: _description_
"""
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
raise ValueError(
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
if bool(ds_meta) == bool(env_cfg):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
# you want this op to be added in priority during the prototype phase of this feature, please comment on
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS.
if cfg.type == "vqbet" and str(device) == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
)
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
policy_cls = get_policy_class(cfg.type)
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
if pretrained_policy_name_or_path is None:
# Make a fresh policy.
policy = policy_cls(policy_cfg, dataset_stats)
kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
kwargs["dataset_stats"] = ds_meta.stats
else:
if not cfg.pretrained_path:
logging.warning(
"You are instantiating a policy from scratch and its features are parsed from an environment "
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
# huggingface_hub should make it possible to avoid the hack:
# https://github.com/huggingface/huggingface_hub/pull/2274.
policy = policy_cls(policy_cfg)
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
policy = policy_cls.from_pretrained(**kwargs)
else:
# Make a fresh policy.
policy = policy_cls(**kwargs)
policy.to(get_safe_torch_device(hydra_cfg.device))
policy.to(device)
assert isinstance(policy, nn.Module)
return policy

View File

@@ -16,10 +16,12 @@
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def create_stats_buffers(
shapes: dict[str, list[int]],
modes: dict[str, str],
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
@@ -34,12 +36,16 @@ def create_stats_buffers(
"""
stats_buffers = {}
for key, mode in modes.items():
assert mode in ["mean_std", "min_max"]
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
shape = tuple(shapes[key])
assert isinstance(norm_mode, NormalizationMode)
if "image" in key:
shape = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# sanity checks
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
c, h, w = shape
@@ -52,7 +58,7 @@ def create_stats_buffers(
# we assert they are not infinity anymore.
buffer = {}
if mode == "mean_std":
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
@@ -61,7 +67,7 @@ def create_stats_buffers(
"std": nn.Parameter(std, requires_grad=False),
}
)
elif mode == "min_max":
elif norm_mode is NormalizationMode.MIN_MAX:
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
@@ -71,15 +77,15 @@ def create_stats_buffers(
}
)
if stats is not None:
if stats:
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if mode == "mean_std":
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone()
buffer["std"].data = stats[key]["std"].clone()
elif mode == "min_max":
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone()
buffer["max"].data = stats[key]["max"].clone()
@@ -99,8 +105,8 @@ class Normalize(nn.Module):
def __init__(
self,
shapes: dict[str, list[int]],
modes: dict[str, str],
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
@@ -122,10 +128,10 @@ class Normalize(nn.Module):
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
self.features = features
self.norm_map = norm_map
self.stats = stats
stats_buffers = create_stats_buffers(shapes, modes, stats)
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -133,16 +139,20 @@ class Normalize(nn.Module):
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():
for key, ft in self.features.items():
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
@@ -152,7 +162,7 @@ class Normalize(nn.Module):
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
raise ValueError(norm_mode)
return batch
@@ -164,8 +174,8 @@ class Unnormalize(nn.Module):
def __init__(
self,
shapes: dict[str, list[int]],
modes: dict[str, str],
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
@@ -187,11 +197,11 @@ class Unnormalize(nn.Module):
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
self.features = features
self.norm_map = norm_map
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats)
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -199,16 +209,20 @@ class Unnormalize(nn.Module):
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():
for key, ft in self.features.items():
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
elif mode == "min_max":
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
@@ -216,5 +230,5 @@ class Unnormalize(nn.Module):
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
raise ValueError(norm_mode)
return batch

View File

@@ -1,75 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
subclass a base class.
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
how to implement new policies.
"""
from typing import Protocol, runtime_checkable
from torch import Tensor
@runtime_checkable
class Policy(Protocol):
"""The required interface for implementing a policy.
We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
"""
name: str
def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
dataset_stats: Dataset statistics to be used for normalization.
"""
def reset(self):
"""To be called whenever the environment is reset.
Does things like clearing caches.
"""
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
"""
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
@runtime_checkable
class PolicyWithUpdate(Policy, Protocol):
def update(self):
"""An update method that is to be called after a training optimization step.
Implements an additional updates the model parameters may need (for example, doing an EMA step for a
target model, or incrementing an internal buffer).
"""

View File

@@ -0,0 +1,182 @@
import abc
import logging
import os
from pathlib import Path
from typing import Type, TypeVar
import packaging
import safetensors
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.common.utils.hub import HubMixin
from lerobot.configs.policies import PreTrainedConfig
T = TypeVar("T", bound="PreTrainedPolicy")
DEFAULT_POLICY_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---
This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
Base class for policy models.
"""
config_class: None
name: None
def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
super().__init__()
if not isinstance(config, PreTrainedConfig):
raise ValueError(
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
"`PreTrainedConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.config = config
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not getattr(cls, "config_class", None):
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'")
def _save_pretrained(self, save_directory: Path) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
@classmethod
def from_pretrained(
cls: Type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
map_location: str = "cpu",
strict: bool = False,
**kwargs,
) -> T:
"""
The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are
deactivated). To train it, you should first set it back in training mode with `policy.train()`.
"""
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs)
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
policy.to(map_location)
policy.eval()
return policy
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
load_model_as_safetensor(model, model_file, strict=strict)
if map_location != "cpu":
logging.warning(
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
" This means that the model is loaded on 'cpu' first and then copied to the device."
" This leads to a slower loading time."
" Please update safetensors to version 0.4.3 or above for improved performance."
)
model.to(map_location)
else:
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return model
# def generate_model_card(self, *args, **kwargs) -> ModelCard:
# card = ModelCard.from_template(
# card_data=self._hub_mixin_info.model_card_data,
# template_str=self._hub_mixin_info.model_card_template,
# repo_url=self._hub_mixin_info.repo_url,
# docs_url=self._hub_mixin_info.docs_url,
# **kwargs,
# )
# return card
@abc.abstractmethod
def get_optim_params(self) -> dict:
"""
Returns the policy-specific parameters dict to be passed on to the optimizer.
"""
raise NotImplementedError
@abc.abstractmethod
def reset(self):
"""To be called whenever the environment is reset.
Does things like clearing caches.
"""
raise NotImplementedError
@abc.abstractmethod
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
"""
raise NotImplementedError
@abc.abstractmethod
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
raise NotImplementedError

View File

@@ -16,9 +16,14 @@
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("tdmpc")
@dataclass
class TDMPCConfig:
class TDMPCConfig(PreTrainedConfig):
"""Configuration class for TDMPCPolicy.
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
@@ -102,27 +107,19 @@ class TDMPCConfig:
"""
# Input / output structure.
n_obs_steps: int = 1
n_action_repeats: int = 2
horizon: int = 5
n_action_steps: int = 1
input_shapes: dict[str, list[int]] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ENV": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.MIN_MAX,
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
# Architecture / modeling.
# Neural networks.
@@ -159,32 +156,27 @@ class TDMPCConfig:
# Target model.
target_model_momentum: float = 0.995
# Training presets
optimizer_lr: float = 3e-4
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) > 1:
raise ValueError(
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
)
if len(image_keys) > 0:
image_key = next(iter(image_keys))
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
)
if self.output_normalization_modes != {"action": "min_max"}:
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
raise ValueError(
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
"information."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.n_action_steps > 1:
if self.n_action_repeats != 1:
raise ValueError(
@@ -194,3 +186,35 @@ class TDMPCConfig:
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(lr=self.optimizer_lr)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
# There should only be one image key.
if len(self.image_features) > 1:
raise ValueError(
f"{self.__class__.__name__} handles at most one image for now. Got image keys {self.image_features}."
)
if len(self.image_features) > 0:
image_ft = next(iter(self.image_features.values()))
if image_ft.shape[-2] != image_ft.shape[-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
@property
def observation_delta_indices(self) -> list:
return list(range(self.horizon + 1))
@property
def action_delta_indices(self) -> list:
return list(range(self.horizon))
@property
def reward_delta_indices(self) -> None:
return list(range(self.horizon))

View File

@@ -33,21 +33,16 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
class TDMPCPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "tdmpc"],
):
class TDMPCPolicy(PreTrainedPolicy):
"""Implementation of TD-MPC learning + inference.
Please note several warnings for this policy.
@@ -65,11 +60,10 @@ class TDMPCPolicy(
match our xarm environment.
"""
config_class = TDMPCConfig
name = "tdmpc"
def __init__(
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
@@ -77,42 +71,28 @@ class TDMPCPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = TDMPCConfig()
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
for param in self.model_target.parameters():
param.requires_grad = False
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
self._use_image = False
self._use_env_state = False
if len(image_keys) > 0:
assert len(image_keys) == 1
self._use_image = True
self.input_image_key = image_keys[0]
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
self.reset()
def get_optim_params(self) -> dict:
return self.parameters()
def reset(self):
"""
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
@@ -122,9 +102,9 @@ class TDMPCPolicy(
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self._use_image:
if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state:
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step.
@@ -134,9 +114,9 @@ class TDMPCPolicy(
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self._use_image:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch["observation.image"] = batch[next(iter(self.config.image_features))]
self._queues = populate_queues(self._queues, batch)
@@ -151,9 +131,9 @@ class TDMPCPolicy(
# NOTE: Order of observations matters here.
encode_keys = []
if self._use_image:
if self.config.image_features:
encode_keys.append("observation.image")
if self._use_env_state:
if self.config.env_state_feature:
encode_keys.append("observation.environment_state")
encode_keys.append("observation.state")
z = self.model.encode({k: batch[k] for k in encode_keys})
@@ -196,7 +176,7 @@ class TDMPCPolicy(
self.config.horizon,
self.config.n_pi_samples,
batch_size,
self.config.output_shapes["action"][0],
self.config.action_feature.shape[0],
device=device,
)
if self.config.n_pi_samples > 0:
@@ -215,7 +195,7 @@ class TDMPCPolicy(
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
@@ -228,7 +208,7 @@ class TDMPCPolicy(
self.config.horizon,
self.config.n_gaussian_samples,
batch_size,
self.config.output_shapes["action"][0],
self.config.action_feature.shape[0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
@@ -330,9 +310,9 @@ class TDMPCPolicy(
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
if self._use_image:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch["observation.image"] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
info = {}
@@ -347,7 +327,7 @@ class TDMPCPolicy(
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations.
if self._use_image and self.config.max_random_shift_ratio > 0:
if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"],
@@ -360,7 +340,7 @@ class TDMPCPolicy(
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[
"observation.image" if self._use_image else "observation.environment_state"
"observation.image" if self.config.image_features else "observation.environment_state"
].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.
@@ -543,7 +523,7 @@ class TDMPCTOLD(nn.Module):
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -554,7 +534,7 @@ class TDMPCTOLD(nn.Module):
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -569,12 +549,12 @@ class TDMPCTOLD(nn.Module):
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]),
nn.Linear(config.mlp_dim, config.action_feature.shape[0]),
)
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -714,10 +694,13 @@ class TDMPCObservationEncoder(nn.Module):
super().__init__()
self.config = config
if "observation.image" in config.input_shapes:
if config.image_features:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
next(iter(config.image_features.values())).shape[0],
config.image_encoder_hidden_dim,
7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
@@ -727,9 +710,8 @@ class TDMPCObservationEncoder(nn.Module):
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
@@ -738,19 +720,19 @@ class TDMPCObservationEncoder(nn.Module):
nn.Sigmoid(),
)
)
if "observation.state" in config.input_shapes:
if config.robot_state_feature:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
if "observation.environment_state" in config.input_shapes:
if config.env_state_feature:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -765,12 +747,16 @@ class TDMPCObservationEncoder(nn.Module):
"""
feat = []
# NOTE: Order of observations matters here.
if "observation.image" in self.config.input_shapes:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
if self.config.image_features:
feat.append(
flatten_forward_unflatten(
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
)
)
if self.config.env_state_feature:
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
if self.config.robot_state_feature:
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
return torch.stack(feat, dim=0).mean(0)

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
@@ -47,3 +48,20 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype
def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
"""
Calculates the output shape of a PyTorch module given an input shape.
Args:
module (nn.Module): a PyTorch module
input_shape (tuple): A tuple representing the input shape, e.g., (batch_size, channels, height, width)
Returns:
tuple: The output shape of the module.
"""
dummy_input = torch.zeros(size=input_shape)
with torch.inference_mode():
output = module(dummy_input)
return tuple(output.shape)

View File

@@ -18,9 +18,15 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("vqbet")
@dataclass
class VQBeTConfig:
class VQBeTConfig(PreTrainedConfig):
"""Configuration class for VQ-BeT.
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@@ -90,26 +96,13 @@ class VQBeTConfig:
n_action_pred_token: int = 3
action_chunk_size: int = 5
input_shapes: dict[str, list[int]] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.image": [3, 96, 96],
"observation.state": [2],
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
# Architecture / modeling.
# Vision backbone.
@@ -139,29 +132,69 @@ class VQBeTConfig:
bet_softmax_temperature: float = 0.1
sequentially_select: bool = False
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
optimizer_vqvae_lr: float = 1e-3
optimizer_vqvae_weight_decay: float = 1e-4
scheduler_warmup_steps: int = 500
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> VQBeTSchedulerConfig:
return VQBeTSchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
num_vqvae_training_steps=self.n_vqvae_training_steps,
)
def validate_features(self) -> None:
# Note: this check was previously performed inside VQBeTRgbEncoder in the form of
# assert len(image_keys) == 1
if not len(self.image_features) == 1:
raise ValueError("You must provide only one image among the inputs.")
if self.crop_shape is not None:
for image_key in image_keys:
if (
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for "
"`input_shapes[{image_key}]`."
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
# Check that all input images have the same shape.
first_image_key = next(iter(image_keys))
for image_key in image_keys:
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
"expect all image shapes to match."
f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
)
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -16,7 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from collections import deque
from typing import Callable, List
@@ -26,29 +25,23 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from torch.optim.lr_scheduler import LambdaLR
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
# ruff: noqa: N806
class VQBeTPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "vqbet"],
):
class VQBeTPolicy(PreTrainedPolicy):
"""
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
"""
config_class = VQBeTConfig
name = "vqbet"
def __init__(
@@ -63,26 +56,62 @@ class VQBeTPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = VQBeTConfig()
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.vqbet = VQBeTModel(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.reset()
def get_optim_params(self) -> dict:
vqvae_params = (
list(self.vqbet.action_head.vqvae_model.encoder.parameters())
+ list(self.vqbet.action_head.vqvae_model.decoder.parameters())
+ list(self.vqbet.action_head.vqvae_model.vq_layer.parameters())
)
decay_params, no_decay_params = self.vqbet.policy.configure_parameters()
decay_params = (
decay_params
+ list(self.vqbet.rgb_encoder.parameters())
+ list(self.vqbet.state_projector.parameters())
+ list(self.vqbet.rgb_feature_projector.parameters())
+ [self.vqbet.action_token]
+ list(self.vqbet.action_head.map_to_cbet_preds_offset.parameters())
)
if self.config.sequentially_select:
decay_params = (
decay_params
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
)
else:
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
return [
{
"params": decay_params,
},
{
"params": vqvae_params,
"weight_decay": self.config.optimizer_vqvae_weight_decay,
"lr": self.config.optimizer_vqvae_lr,
},
{
"params": no_decay_params,
"weight_decay": 0.0,
},
]
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
@@ -105,7 +134,7 @@ class VQBeTPolicy(
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -131,7 +160,7 @@ class VQBeTPolicy(
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -288,14 +317,14 @@ class VQBeTModel(nn.Module):
self.config = config
self.rgb_encoder = VQBeTRgbEncoder(config)
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self.num_images = len(self.config.image_features)
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP(
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
)
self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
@@ -350,10 +379,10 @@ class VQBeTModel(nn.Module):
# get action features (pass through GPT)
features = self.policy(input_tokens)
# len(self.config.input_shapes) is the number of different observation modes.
# len(self.config.input_features) is the number of different observation modes.
# this line gets the index of action prompt tokens.
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
self.config.input_shapes
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
self.config.input_features
)
# only extract the output tokens at the position of action query:
@@ -392,7 +421,7 @@ class VQBeTHead(nn.Module):
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`.
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`.
"""
super().__init__()
@@ -419,7 +448,7 @@ class VQBeTHead(nn.Module):
self.vqvae_model.vqvae_num_layers
* self.config.vqvae_n_embed
* config.action_chunk_size
* config.output_shapes["action"][0],
* config.action_feature.shape[0],
],
)
# loss
@@ -623,84 +652,6 @@ class VQBeTHead(nn.Module):
return loss_dict
class VQBeTOptimizer(torch.optim.Adam):
def __init__(self, policy, cfg):
vqvae_params = (
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
)
decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
decay_params = (
decay_params
+ list(policy.vqbet.rgb_encoder.parameters())
+ list(policy.vqbet.state_projector.parameters())
+ list(policy.vqbet.rgb_feature_projector.parameters())
+ [policy.vqbet.action_token]
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
)
if cfg.policy.sequentially_select:
decay_params = (
decay_params
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
)
else:
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
optim_groups = [
{
"params": decay_params,
"weight_decay": cfg.training.adam_weight_decay,
"lr": cfg.training.lr,
},
{
"params": vqvae_params,
"weight_decay": 0.0001,
"lr": cfg.training.vqvae_lr,
},
{
"params": no_decay_params,
"weight_decay": 0.0,
"lr": cfg.training.lr,
},
]
super().__init__(
optim_groups,
cfg.training.lr,
cfg.training.adam_betas,
cfg.training.adam_eps,
)
class VQBeTScheduler(nn.Module):
def __init__(self, optimizer, cfg):
super().__init__()
n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
num_warmup_steps = cfg.training.lr_warmup_steps
num_training_steps = cfg.training.offline_steps
num_cycles = 0.5
def lr_lambda(current_step):
if current_step < n_vqvae_training_steps:
return float(1)
else:
current_step = current_step - n_vqvae_training_steps
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
def step(self):
self.lr_scheduler.step()
class VQBeTRgbEncoder(nn.Module):
"""Encode an RGB image into a 1D feature vector.
@@ -743,19 +694,15 @@ class VQBeTRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
# height and width from `config.image_features`.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
@@ -844,7 +791,7 @@ class VqVae(nn.Module):
)
self.encoder = MLP(
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
@@ -856,7 +803,7 @@ class VqVae(nn.Module):
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
self.config.output_shapes["action"][0] * self.config.action_chunk_size,
self.config.action_feature.shape[0] * self.config.action_chunk_size,
],
)
@@ -872,9 +819,9 @@ class VqVae(nn.Module):
# given latent vector, this function outputs the decoded action.
output = self.decoder(latent)
if self.config.action_chunk_size == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
def get_code(self, state):
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)

View File

@@ -0,0 +1,100 @@
import abc
from dataclasses import dataclass
import draccus
@dataclass
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@CameraConfig.register_subclass("opencv")
@dataclass
class OpenCVCameraConfig(CameraConfig):
"""
Example of tested options for Intel Real Sense D405:
```python
OpenCVCameraConfig(0, 30, 640, 480)
OpenCVCameraConfig(0, 60, 640, 480)
OpenCVCameraConfig(0, 90, 640, 480)
OpenCVCameraConfig(0, 30, 1280, 720)
```
"""
camera_index: int
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
rotation: int | None = None
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
@CameraConfig.register_subclass("intelrealsense")
@dataclass
class IntelRealSenseCameraConfig(CameraConfig):
"""
Example of tested options for Intel Real Sense D405:
```python
IntelRealSenseCameraConfig(128422271347, 30, 640, 480)
IntelRealSenseCameraConfig(128422271347, 60, 640, 480)
IntelRealSenseCameraConfig(128422271347, 90, 640, 480)
IntelRealSenseCameraConfig(128422271347, 30, 1280, 720)
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
```
"""
name: str | None = None
serial_number: int | None = None
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
use_depth: bool = False
force_hardware_reset: bool = True
rotation: int | None = None
mock: bool = False
def __post_init__(self):
# bool is stronger than is None, since it works with empty strings
if bool(self.name) and bool(self.serial_number):
raise ValueError(
f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided."
)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
)
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")

View File

@@ -11,13 +11,13 @@ import threading
import time
import traceback
from collections import Counter
from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
import numpy as np
from PIL import Image
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
@@ -94,7 +94,10 @@ def save_images_from_cameras(
cameras = []
for cam_sn in serial_numbers:
print(f"{cam_sn=}")
camera = IntelRealSenseCamera(cam_sn, fps=fps, width=width, height=height, mock=mock)
config = IntelRealSenseCameraConfig(
serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock
)
camera = IntelRealSenseCamera(config)
camera.connect()
print(
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
@@ -149,51 +152,6 @@ def save_images_from_cameras(
camera.disconnect()
@dataclass
class IntelRealSenseCameraConfig:
"""
Example of tested options for Intel Real Sense D405:
```python
IntelRealSenseCameraConfig(30, 640, 480)
IntelRealSenseCameraConfig(60, 640, 480)
IntelRealSenseCameraConfig(90, 640, 480)
IntelRealSenseCameraConfig(30, 1280, 720)
IntelRealSenseCameraConfig(30, 640, 480, use_depth=True)
IntelRealSenseCameraConfig(30, 640, 480, rotation=90)
```
"""
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
use_depth: bool = False
force_hardware_reset: bool = True
rotation: int | None = None
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
)
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
class IntelRealSenseCamera:
"""
The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
@@ -209,33 +167,35 @@ class IntelRealSenseCamera:
When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
of the given camera will be used.
Example of usage:
Example of instantiating with a serial number:
```python
# Instantiate with its serial number
camera = IntelRealSenseCamera(128422271347)
# Or by its name if it's unique
camera = IntelRealSenseCamera.init_from_name("Intel RealSense D405")
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
config = IntelRealSenseCameraConfig(serial_number=128422271347)
camera = IntelRealSenseCamera(config)
camera.connect()
color_image = camera.read()
# when done using the camera, consider disconnecting
camera.disconnect()
```
Example of instantiating with a name if it's unique:
```
config = IntelRealSenseCameraConfig(name="Intel RealSense D405")
```
Example of changing default fps, width, height and color_mode:
```python
camera = IntelRealSenseCamera(serial_number, fps=30, width=1280, height=720)
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
camera = IntelRealSenseCamera(serial_number, fps=90, width=640, height=480)
camera = connect()
camera = IntelRealSenseCamera(serial_number, fps=90, width=640, height=480, color_mode="bgr")
camera = connect()
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
# Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
```
Example of returning depth:
```python
camera = IntelRealSenseCamera(serial_number, use_depth=True)
config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True)
camera = IntelRealSenseCamera(config)
camera.connect()
color_image, depth_map = camera.read()
```
@@ -243,17 +203,13 @@ class IntelRealSenseCamera:
def __init__(
self,
serial_number: int,
config: IntelRealSenseCameraConfig | None = None,
**kwargs,
config: IntelRealSenseCameraConfig,
):
if config is None:
config = IntelRealSenseCameraConfig()
# Overwrite the config arguments using kwargs
config = replace(config, **kwargs)
self.serial_number = serial_number
self.config = config
if config.name is not None:
self.serial_number = self.find_serial_number_from_name(config.name)
else:
self.serial_number = config.serial_number
self.fps = config.fps
self.width = config.width
self.height = config.height
@@ -285,8 +241,7 @@ class IntelRealSenseCamera:
elif config.rotation == 180:
self.rotation = cv2.ROTATE_180
@classmethod
def init_from_name(cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs):
def find_serial_number_from_name(self, name):
camera_infos = find_cameras()
camera_names = [cam["name"] for cam in camera_infos]
this_name_count = Counter(camera_names)[name]
@@ -299,13 +254,7 @@ class IntelRealSenseCamera:
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
cam_sn = name_to_serial_dict[name]
if config is None:
config = IntelRealSenseCameraConfig()
# Overwrite the config arguments using kwargs
config = replace(config, **kwargs)
return cls(serial_number=cam_sn, config=config, **kwargs)
return cam_sn
def connect(self):
if self.is_connected:

View File

@@ -9,13 +9,13 @@ import platform
import shutil
import threading
import time
from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
import numpy as np
from PIL import Image
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
@@ -126,7 +126,8 @@ def save_images_from_cameras(
print("Connecting cameras")
cameras = []
for cam_idx in camera_ids:
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height, mock=mock)
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
camera = OpenCVCamera(config)
camera.connect()
print(
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
@@ -175,39 +176,6 @@ def save_images_from_cameras(
print(f"Images have been saved to {images_dir}")
@dataclass
class OpenCVCameraConfig:
"""
Example of tested options for Intel Real Sense D405:
```python
OpenCVCameraConfig(30, 640, 480)
OpenCVCameraConfig(60, 640, 480)
OpenCVCameraConfig(90, 640, 480)
OpenCVCameraConfig(30, 1280, 720)
```
"""
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
rotation: int | None = None
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
class OpenCVCamera:
"""
The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
@@ -227,7 +195,10 @@ class OpenCVCamera:
Example of usage:
```python
camera = OpenCVCamera(camera_index=0)
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
config = OpenCVCameraConfig(camera_index=0)
camera = OpenCVCamera(config)
camera.connect()
color_image = camera.read()
# when done using the camera, consider disconnecting
@@ -236,25 +207,16 @@ class OpenCVCamera:
Example of changing default fps, width, height and color_mode:
```python
camera = OpenCVCamera(0, fps=30, width=1280, height=720)
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
camera = OpenCVCamera(0, fps=90, width=640, height=480)
camera = connect()
camera = OpenCVCamera(0, fps=90, width=640, height=480, color_mode="bgr")
camera = connect()
config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720)
config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480)
config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr")
# Note: might error out open `camera.connect()` if these settings are not compatible with the camera
```
"""
def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = None, **kwargs):
if config is None:
config = OpenCVCameraConfig()
# Overwrite config arguments using kwargs
config = replace(config, **kwargs)
self.camera_index = camera_index
def __init__(self, config: OpenCVCameraConfig):
self.config = config
self.camera_index = config.camera_index
self.port = None
# Linux uses ports for connecting to cameras
@@ -266,7 +228,7 @@ class OpenCVCamera:
# Retrieve the camera index from a potentially symlinked path
self.camera_index = get_camera_index_from_unix_port(self.port)
else:
raise ValueError(f"Please check the provided camera_index: {camera_index}")
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
self.fps = config.fps
self.width = config.width

View File

@@ -2,6 +2,12 @@ from typing import Protocol
import numpy as np
from lerobot.common.robot_devices.cameras.configs import (
CameraConfig,
IntelRealSenseCameraConfig,
OpenCVCameraConfig,
)
# Defines a camera type
class Camera(Protocol):
@@ -9,3 +15,39 @@ class Camera(Protocol):
def read(self, temporary_color: str | None = None) -> np.ndarray: ...
def async_read(self) -> np.ndarray: ...
def disconnect(self): ...
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[Camera]:
cameras = {}
for key, cfg in camera_configs.items():
if cfg.type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
cameras[key] = OpenCVCamera(cfg)
elif cfg.type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
cameras[key] = IntelRealSenseCamera(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
return cameras
def make_camera(camera_type, **kwargs) -> Camera:
if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
config = OpenCVCameraConfig(**kwargs)
return OpenCVCamera(config)
elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
config = IntelRealSenseCameraConfig(**kwargs)
return IntelRealSenseCamera(config)
else:
raise ValueError(f"The camera type '{camera_type}' is not valid.")

View File

@@ -0,0 +1,146 @@
import logging
from dataclasses import dataclass
from pathlib import Path
import draccus
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
@dataclass
class ControlConfig(draccus.ChoiceRegistry):
pass
@ControlConfig.register_subclass("calibrate")
@dataclass
class CalibrateControlConfig(ControlConfig):
# List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`)
arms: list[str] | None = None
@ControlConfig.register_subclass("teleoperate")
@dataclass
class TeleoperateControlConfig(ControlConfig):
# Limit the maximum frames per second. By default, no limit.
fps: int | None = None
teleop_time_s: float | None = None
# Display all cameras on screen
display_cameras: bool = True
@ControlConfig.register_subclass("record")
@dataclass
class RecordControlConfig(ControlConfig):
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
policy: PreTrainedConfig | None = None
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
device: str | None = None # cuda | cpu | mps
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int | None = None
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
warmup_time_s: int | float = 10
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.
run_compute_stats: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Display all cameras on screen
display_cameras: bool = True
# Use vocal synthesis to read events.
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
local_files_only: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("control.policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("control.policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
# When no device or use_amp are given, use the one from training config.
if self.device is None or self.use_amp is None:
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
if self.device is None:
self.device = train_cfg.device
if self.use_amp is None:
self.use_amp = train_cfg.use_amp
# Automatically switch to available device if necessary
if not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@ControlConfig.register_subclass("replay")
@dataclass
class ReplayControlConfig(ControlConfig):
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# Index of the episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second. By default, uses the dataset fps.
fps: int | None = None
# Use vocal synthesis to read events.
play_sounds: bool = True
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
local_files_only: bool = False
@dataclass
class ControlPipelineConfig:
robot: RobotConfig
control: ControlConfig
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["control.policy"]

View File

@@ -19,11 +19,9 @@ from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.common.utils.utils import get_safe_torch_device, has_method
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
@@ -89,10 +87,6 @@ def is_headless():
return True
def has_method(_object: object, method_name: str):
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with (
@@ -161,26 +155,6 @@ def init_keyboard_listener():
return listener, events
def init_policy(pretrained_policy_name_or_path, policy_overrides):
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
use_amp = hydra_cfg.use_amp
policy_fps = hydra_cfg.env.fps
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
return policy, policy_fps, device, use_amp
def warmup_record(
robot,
events,
@@ -233,9 +207,9 @@ def control_loop(
dataset: LeRobotDataset | None = None,
events=None,
policy=None,
device=None,
use_amp=None,
fps=None,
device: torch.device | str | None = None,
use_amp: bool | None = None,
fps: int | None = None,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@@ -253,6 +227,9 @@ def control_loop(
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
if isinstance(device, str):
device = get_safe_torch_device(device)
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
@@ -324,21 +301,21 @@ def stop_recording(robot, listener, display_cameras):
cv2.destroyAllWindows()
def sanity_check_dataset_name(repo_id, policy):
def sanity_check_dataset_name(repo_id, policy_cfg):
_, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy
# or repo_id starts with "eval_" and there is a policy
# Check if dataset_name starts with "eval_" but policy is missing
if dataset_name.startswith("eval_") and policy is None:
if dataset_name.startswith("eval_") and policy_cfg is None:
raise ValueError(
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
)
# Check if dataset_name does not start with "eval_" but policy is provided
if not dataset_name.startswith("eval_") and policy is not None:
if not dataset_name.startswith("eval_") and policy_cfg is not None:
raise ValueError(
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})."
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
)

View File

@@ -0,0 +1,27 @@
import abc
from dataclasses import dataclass
import draccus
@dataclass
class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@MotorsBusConfig.register_subclass("dynamixel")
@dataclass
class DynamixelMotorsBusConfig(MotorsBusConfig):
port: str
motors: dict[str, tuple[int, str]]
mock: bool = False
@MotorsBusConfig.register_subclass("feetech")
@dataclass
class FeetechMotorsBusConfig(MotorsBusConfig):
port: str
motors: dict[str, tuple[int, str]]
mock: bool = False

View File

@@ -8,6 +8,7 @@ from copy import deepcopy
import numpy as np
import tqdm
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
@@ -252,7 +253,6 @@ class JointOutOfRangeError(Exception):
class DynamixelMotorsBus:
# TODO(rcadene): Add a script to find the motor indices without DynamixelWizzard2
"""
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
@@ -274,10 +274,11 @@ class DynamixelMotorsBus:
motor_index = 6
motor_model = "xl330-m288"
motors_bus = DynamixelMotorsBus(
config = DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751",
motors={motor_name: (motor_index, motor_model)},
)
motors_bus = DynamixelMotorsBus(config)
motors_bus.connect()
position = motors_bus.read("Present_Position")
@@ -293,23 +294,14 @@ class DynamixelMotorsBus:
def __init__(
self,
port: str,
motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None,
extra_model_resolution: dict[str, int] | None = None,
mock=False,
config: DynamixelMotorsBusConfig,
):
self.port = port
self.motors = motors
self.mock = mock
self.port = config.port
self.motors = config.motors
self.mock = config.mock
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table)
self.model_resolution = deepcopy(MODEL_RESOLUTION)
if extra_model_resolution:
self.model_resolution.update(extra_model_resolution)
self.port_handler = None
self.packet_handler = None

View File

@@ -8,6 +8,7 @@ from copy import deepcopy
import numpy as np
import tqdm
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
@@ -252,10 +253,11 @@ class FeetechMotorsBus:
motor_index = 6
motor_model = "sts3215"
motors_bus = FeetechMotorsBus(
config = FeetechMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751",
motors={motor_name: (motor_index, motor_model)},
)
motors_bus = FeetechMotorsBus(config)
motors_bus.connect()
position = motors_bus.read("Present_Position")
@@ -271,23 +273,14 @@ class FeetechMotorsBus:
def __init__(
self,
port: str,
motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None,
extra_model_resolution: dict[str, int] | None = None,
mock=False,
config: FeetechMotorsBusConfig,
):
self.port = port
self.motors = motors
self.mock = mock
self.port = config.port
self.motors = config.motors
self.mock = config.mock
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table)
self.model_resolution = deepcopy(MODEL_RESOLUTION)
if extra_model_resolution:
self.model_resolution.update(extra_model_resolution)
self.port_handler = None
self.packet_handler = None

View File

@@ -1,5 +1,11 @@
from typing import Protocol
from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig,
FeetechMotorsBusConfig,
MotorsBusConfig,
)
class MotorsBus(Protocol):
def motor_names(self): ...
@@ -8,3 +14,40 @@ class MotorsBus(Protocol):
def revert_calibration(self): ...
def read(self): ...
def write(self): ...
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
motors_buses = {}
for key, cfg in motors_bus_configs.items():
if cfg.type == "dynamixel":
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
motors_buses[key] = DynamixelMotorsBus(cfg)
elif cfg.type == "feetech":
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
motors_buses[key] = FeetechMotorsBus(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
return motors_buses
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
if motor_type == "dynamixel":
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
config = DynamixelMotorsBusConfig(**kwargs)
return DynamixelMotorsBus(config)
elif motor_type == "feetech":
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
config = FeetechMotorsBusConfig(**kwargs)
return FeetechMotorsBus(config)
else:
raise ValueError(f"The motor type '{motor_type}' is not valid.")

View File

@@ -0,0 +1,516 @@
import abc
from dataclasses import dataclass, field
from typing import Sequence
import draccus
from lerobot.common.robot_devices.cameras.configs import (
CameraConfig,
IntelRealSenseCameraConfig,
OpenCVCameraConfig,
)
from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig,
FeetechMotorsBusConfig,
MotorsBusConfig,
)
@dataclass
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction
@dataclass
class ManipulatorRobotConfig(RobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
# as the number of motors in your follower arms (assumes all follower arms have the same number of
# motors).
max_relative_target: list[float] | float | None = None
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
mock: bool = False
def __post_init__(self):
if self.mock:
for arm in self.leader_arms.values():
if not arm.mock:
arm.mock = True
for arm in self.follower_arms.values():
if not arm.mock:
arm.mock = True
for cam in self.cameras.values():
if not cam.mock:
cam.mock = True
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
raise ValueError(
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
f"`max_relative_target` list has as many parameters as there are motors per arm. "
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
@RobotConfig.register_subclass("aloha")
@dataclass
class AlohaRobotConfig(ManipulatorRobotConfig):
# Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been
# properly assembled, no manual calibration step is expected. If you need to run manual calibration,
# simply update this path to ".cache/calibration/aloha"
calibration_dir: str = ".cache/calibration/aloha_default"
# /!\ FOR SAFETY, READ THIS /!\
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
# When you feel more confident with teleoperation or running the policy, you can extend
# this safety limit and even removing it by setting it to `null`.
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
max_relative_target: int | None = 5
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
# window_x
port="/dev/ttyDXL_leader_left",
motors={
# name: (index, model)
"waist": [1, "xm430-w350"],
"shoulder": [2, "xm430-w350"],
"shoulder_shadow": [3, "xm430-w350"],
"elbow": [4, "xm430-w350"],
"elbow_shadow": [5, "xm430-w350"],
"forearm_roll": [6, "xm430-w350"],
"wrist_angle": [7, "xm430-w350"],
"wrist_rotate": [8, "xl430-w250"],
"gripper": [9, "xc430-w150"],
},
),
"right": DynamixelMotorsBusConfig(
# window_x
port="/dev/ttyDXL_leader_right",
motors={
# name: (index, model)
"waist": [1, "xm430-w350"],
"shoulder": [2, "xm430-w350"],
"shoulder_shadow": [3, "xm430-w350"],
"elbow": [4, "xm430-w350"],
"elbow_shadow": [5, "xm430-w350"],
"forearm_roll": [6, "xm430-w350"],
"wrist_angle": [7, "xm430-w350"],
"wrist_rotate": [8, "xl430-w250"],
"gripper": [9, "xc430-w150"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
port="/dev/ttyDXL_follower_left",
motors={
# name: (index, model)
"waist": [1, "xm540-w270"],
"shoulder": [2, "xm540-w270"],
"shoulder_shadow": [3, "xm540-w270"],
"elbow": [4, "xm540-w270"],
"elbow_shadow": [5, "xm540-w270"],
"forearm_roll": [6, "xm540-w270"],
"wrist_angle": [7, "xm540-w270"],
"wrist_rotate": [8, "xm430-w350"],
"gripper": [9, "xm430-w350"],
},
),
"right": DynamixelMotorsBusConfig(
port="/dev/ttyDXL_follower_right",
motors={
# name: (index, model)
"waist": [1, "xm540-w270"],
"shoulder": [2, "xm540-w270"],
"shoulder_shadow": [3, "xm540-w270"],
"elbow": [4, "xm540-w270"],
"elbow_shadow": [5, "xm540-w270"],
"forearm_roll": [6, "xm540-w270"],
"wrist_angle": [7, "xm540-w270"],
"wrist_rotate": [8, "xm430-w350"],
"gripper": [9, "xm430-w350"],
},
),
}
)
# Troubleshooting: If one of your IntelRealSense cameras freeze during
# data recording due to bandwidth limit, you might need to plug the camera
# on another USB hub or PCIe card.
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"cam_high": IntelRealSenseCameraConfig(
serial_number=128422271347,
fps=30,
width=640,
height=480,
),
"cam_low": IntelRealSenseCameraConfig(
serial_number=130322270656,
fps=30,
width=640,
height=480,
),
"cam_left_wrist": IntelRealSenseCameraConfig(
serial_number=218622272670,
fps=30,
width=640,
height=480,
),
"cam_right_wrist": IntelRealSenseCameraConfig(
serial_number=130322272300,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("koch")
@dataclass
class KochRobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/koch"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0085511",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl330-m077"],
"shoulder_lift": [2, "xl330-m077"],
"elbow_flex": [3, "xl330-m077"],
"wrist_flex": [4, "xl330-m077"],
"wrist_roll": [5, "xl330-m077"],
"gripper": [6, "xl330-m077"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
# ~ Koch specific settings ~
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
# to squeeze the gripper and have it spring back to an open position on its own.
gripper_open_degree: float = 35.156
mock: bool = False
@RobotConfig.register_subclass("koch_bimanual")
@dataclass
class KochBimanualRobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/koch_bimanual"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0085511",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl330-m077"],
"shoulder_lift": [2, "xl330-m077"],
"elbow_flex": [3, "xl330-m077"],
"wrist_flex": [4, "xl330-m077"],
"wrist_roll": [5, "xl330-m077"],
"gripper": [6, "xl330-m077"],
},
),
"right": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl330-m077"],
"shoulder_lift": [2, "xl330-m077"],
"elbow_flex": [3, "xl330-m077"],
"wrist_flex": [4, "xl330-m077"],
"wrist_roll": [5, "xl330-m077"],
"gripper": [6, "xl330-m077"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
},
),
"right": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
# ~ Koch specific settings ~
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
# to squeeze the gripper and have it spring back to an open position on its own.
gripper_open_degree: float = 35.156
mock: bool = False
@RobotConfig.register_subclass("moss")
@dataclass
class MossRobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/moss"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("so100")
@dataclass
class So100RobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/so100"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("stretch")
@dataclass
class StretchRobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"navigation": OpenCVCameraConfig(
camera_index="/dev/hello-nav-head-camera",
fps=10,
width=1280,
height=720,
rotation=-90,
),
"head": IntelRealSenseCameraConfig(
name="Intel RealSense D435I",
fps=30,
width=640,
height=480,
rotation=90,
),
"wrist": IntelRealSenseCameraConfig(
name="Intel RealSense D405",
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False

View File

@@ -1,9 +0,0 @@
import hydra
from omegaconf import DictConfig
from lerobot.common.robot_devices.robots.utils import Robot
def make_robot(cfg: DictConfig) -> Robot:
robot = hydra.utils.instantiate(cfg)
return robot

View File

@@ -8,15 +8,14 @@ import json
import logging
import time
import warnings
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Sequence
import numpy as np
import torch
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
@@ -41,50 +40,6 @@ def ensure_safe_goal_position(
return safe_goal_pos
@dataclass
class ManipulatorRobotConfig:
"""
Example of usage:
```python
ManipulatorRobotConfig()
```
"""
# Define all components of the robot
robot_type: str = "koch"
leader_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
cameras: dict[str, Camera] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
# as the number of motors in your follower arms (assumes all follower arms have the same number of
# motors).
max_relative_target: list[float] | float | None = None
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
def __setattr__(self, prop: str, val):
if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(val):
raise ValueError(
f"len(max_relative_target)={len(val)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
f"`max_relative_target` list has as many parameters as there are motors per arm. "
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
super().__setattr__(prop, val)
def __post_init__(self):
if self.robot_type not in ["koch", "koch_bimanual", "aloha", "so100", "moss"]:
raise ValueError(f"Provided robot type ({self.robot_type}) is not supported.")
class ManipulatorRobot:
# TODO(rcadene): Implement force feedback
"""This class allows to control any manipulator robot of various number of motors.
@@ -95,11 +50,16 @@ class ManipulatorRobot:
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
- [Aloha](https://www.trossenrobotics.com/aloha-kits) developed by Trossen Robotics
Example of highest frequency teleoperation without camera:
Example of instantiation, a pre-defined robot config is required:
```python
robot = ManipulatorRobot(KochRobotConfig())
```
Example of overwritting motors during instantiation:
```python
# Defines how to communicate with the motors of the leader and follower arms
leader_arms = {
"main": DynamixelMotorsBus(
"main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
@@ -113,7 +73,7 @@ class ManipulatorRobot:
),
}
follower_arms = {
"main": DynamixelMotorsBus(
"main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
@@ -126,35 +86,11 @@ class ManipulatorRobot:
},
),
}
robot = ManipulatorRobot(
robot_type="koch",
calibration_dir=".cache/calibration/koch",
leader_arms=leader_arms,
follower_arms=follower_arms,
)
# Connect motors buses and cameras if any (Required)
robot.connect()
while True:
robot.teleop_step()
robot_config = KochRobotConfig(leader_arms=leader_arms, follower_arms=follower_arms)
robot = ManipulatorRobot(robot_config)
```
Example of highest frequency data collection without camera:
```python
# Assumes leader and follower arms have been instantiated already (see first example)
robot = ManipulatorRobot(
robot_type="koch",
calibration_dir=".cache/calibration/koch",
leader_arms=leader_arms,
follower_arms=follower_arms,
)
robot.connect()
while True:
observation, action = robot.teleop_step(record_data=True)
```
Example of highest frequency data collection with cameras:
Example of overwritting cameras during instantiation:
```python
# Defines how to communicate with 2 cameras connected to the computer.
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
@@ -164,31 +100,28 @@ class ManipulatorRobot:
"laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
"phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
}
robot = ManipulatorRobot(KochRobotConfig(cameras=cameras))
```
# Assumes leader and follower arms have been instantiated already (see first example)
robot = ManipulatorRobot(
robot_type="koch",
calibration_dir=".cache/calibration/koch",
leader_arms=leader_arms,
follower_arms=follower_arms,
cameras=cameras,
)
Once the robot is instantiated, connect motors buses and cameras if any (Required):
```python
robot.connect()
```
Example of highest frequency teleoperation, which doesn't require cameras:
```python
while True:
robot.teleop_step()
```
Example of highest frequency data collection from motors and cameras (if any):
```python
while True:
observation, action = robot.teleop_step(record_data=True)
```
Example of controlling the robot with a policy (without running multiple policies in parallel to ensure highest frequency):
Example of controlling the robot with a policy:
```python
# Assumes leader and follower arms + cameras have been instantiated already (see previous example)
robot = ManipulatorRobot(
robot_type="koch",
calibration_dir=".cache/calibration/koch",
leader_arms=leader_arms,
follower_arms=follower_arms,
cameras=cameras,
)
robot.connect()
while True:
# Uses the follower arms and cameras to capture an observation
observation = robot.capture_observation()
@@ -209,20 +142,14 @@ class ManipulatorRobot:
def __init__(
self,
config: ManipulatorRobotConfig | None = None,
calibration_dir: Path = ".cache/calibration/koch",
**kwargs,
config: ManipulatorRobotConfig,
):
if config is None:
config = ManipulatorRobotConfig()
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
self.calibration_dir = Path(calibration_dir)
self.robot_type = self.config.robot_type
self.leader_arms = self.config.leader_arms
self.follower_arms = self.config.follower_arms
self.cameras = self.config.cameras
self.config = config
self.robot_type = self.config.type
self.calibration_dir = Path(self.config.calibration_dir)
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
self.cameras = make_cameras_from_configs(self.config.cameras)
self.is_connected = False
self.logs = {}

View File

@@ -15,23 +15,14 @@
# limitations under the License.
import time
from dataclasses import dataclass, field, replace
from dataclasses import replace
import torch
from stretch_body.gamepad_teleop import GamePadTeleop
from stretch_body.robot import Robot as StretchAPI
from stretch_body.robot_params import RobotParams
from lerobot.common.robot_devices.cameras.utils import Camera
@dataclass
class StretchRobotConfig:
robot_type: str | None = "stretch"
cameras: dict[str, Camera] = field(default_factory=lambda: {})
# TODO(aliberts): add feature with max_relative target
# TODO(aliberts): add comment on max_relative target
max_relative_target: list[float] | float | None = None
from lerobot.common.robot_devices.robots.configs import StretchRobotConfig
class StretchRobot(StretchAPI):
@@ -40,11 +31,12 @@ class StretchRobot(StretchAPI):
def __init__(self, config: StretchRobotConfig | None = None, **kwargs):
super().__init__()
if config is None:
config = StretchRobotConfig()
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
self.config = StretchRobotConfig(**kwargs)
else:
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
self.robot_type = self.config.robot_type
self.robot_type = self.config.type
self.cameras = self.config.cameras
self.is_connected = False
self.teleop = None

View File

@@ -1,5 +1,16 @@
from typing import Protocol
from lerobot.common.robot_devices.robots.configs import (
AlohaRobotConfig,
KochBimanualRobotConfig,
KochRobotConfig,
ManipulatorRobotConfig,
MossRobotConfig,
RobotConfig,
So100RobotConfig,
StretchRobotConfig,
)
def get_arm_id(name, arm_type):
"""Returns the string identifier of a robot arm. For instance, for a bimanual manipulator
@@ -19,3 +30,36 @@ class Robot(Protocol):
def capture_observation(self): ...
def send_action(self, action): ...
def disconnect(self): ...
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
if robot_type == "aloha":
return AlohaRobotConfig(**kwargs)
elif robot_type == "koch":
return KochRobotConfig(**kwargs)
elif robot_type == "koch_bimanual":
return KochBimanualRobotConfig(**kwargs)
elif robot_type == "moss":
return MossRobotConfig(**kwargs)
elif robot_type == "so100":
return So100RobotConfig(**kwargs)
elif robot_type == "stretch":
return StretchRobotConfig(**kwargs)
else:
raise ValueError(f"Robot type '{robot_type}' is not available.")
def make_robot_from_config(config: RobotConfig):
if isinstance(config, ManipulatorRobotConfig):
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
return ManipulatorRobot(config)
else:
from lerobot.common.robot_devices.robots.stretch import StretchRobot
return StretchRobot(config)
def make_robot(robot_type: str, **kwargs) -> Robot:
config = make_robot_config(robot_type, **kwargs)
return make_robot_from_config(config)

188
lerobot/common/utils/hub.py Normal file
View File

@@ -0,0 +1,188 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Type, TypeVar
from huggingface_hub import HfApi
from huggingface_hub.utils import validate_hf_hub_args
T = TypeVar("T", bound="HubMixin")
class HubMixin:
"""
A Mixin containing the functionality to push an object to the hub.
This is similar to huggingface_hub.ModelHubMixin but is lighter and makes less assumptions about its
subclasses (in particular, the fact that it's not necessarily a model).
The inheriting classes must implement '_save_pretrained' and 'from_pretrained'.
"""
def save_pretrained(
self,
save_directory: str | Path,
*,
repo_id: str | None = None,
push_to_hub: bool = False,
card_kwargs: dict[str, Any] | None = None,
**push_to_hub_kwargs,
) -> str | None:
"""
Save object in local directory.
Args:
save_directory (`str` or `Path`):
Path to directory in which the object will be saved.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your object to the Huggingface Hub after saving it.
repo_id (`str`, *optional*):
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
not provided.
card_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments passed to the card template to customize the card.
push_to_hub_kwargs:
Additional key word arguments passed along to the [`~HubMixin.push_to_hub`] method.
Returns:
`str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
"""
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
# save object (weights, files, etc.)
self._save_pretrained(save_directory)
# push to the Hub if required
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return None
def _save_pretrained(self, save_directory: Path) -> None:
"""
Overwrite this method in subclass to define how to save your object.
Args:
save_directory (`str` or `Path`):
Path to directory in which the object files will be saved.
"""
raise NotImplementedError
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls: Type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**kwargs,
) -> T:
"""
Download the object from the Huggingface Hub and instantiate it.
Args:
pretrained_name_or_path (`str`, `Path`):
- Either the `repo_id` (string) of the object hosted on the Hub, e.g. `lerobot/diffusion_pusht`.
- Or a path to a `directory` containing the object files saved using `.save_pretrained`,
e.g., `../path/to/my_model_directory/`.
revision (`str`, *optional*):
Revision on the Hub. Can be a branch name, a git tag or any commit id.
Defaults to the latest commit on `main` branch.
force_download (`bool`, *optional*, defaults to `False`):
Whether to force (re-)downloading the files from the Hub, overriding the existing cache.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
cached when running `huggingface-cli login`.
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, avoid downloading the file and return the path to the local cached file if it exists.
kwargs (`Dict`, *optional*):
Additional kwargs to pass to the object during initialization.
"""
raise NotImplementedError
@validate_hf_hub_args
def push_to_hub(
self,
repo_id: str,
*,
commit_message: str | None = None,
private: bool | None = None,
token: str | None = None,
branch: str | None = None,
create_pr: bool | None = None,
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
delete_patterns: list[str] | str | None = None,
card_kwargs: dict[str, Any] | None = None,
) -> str:
"""
Upload model checkpoint to the Hub.
Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
`delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
details.
Args:
repo_id (`str`):
ID of the repository to push to (example: `"username/my-model"`).
commit_message (`str`, *optional*):
Message to commit while pushing.
private (`bool`, *optional*):
Whether the repository created should be private.
If `None` (default), the repo will be public unless the organization's default is private.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
cached when running `huggingface-cli login`.
branch (`str`, *optional*):
The git branch on which to push the model. This defaults to `"main"`.
create_pr (`boolean`, *optional*):
Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
allow_patterns (`List[str]` or `str`, *optional*):
If provided, only files matching at least one pattern are pushed.
ignore_patterns (`List[str]` or `str`, *optional*):
If provided, files matching any of the patterns are not pushed.
delete_patterns (`List[str]` or `str`, *optional*):
If provided, remote files matching any of the patterns will be deleted from the repo.
card_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments passed to the card template to customize the card.
Returns:
The url of the commit of your object in the given repository.
"""
api = HfApi(token=token)
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
if commit_message is None:
if "Policy" in self.__class__.__name__:
commit_message = "Upload policy"
elif "Config" in self.__class__.__name__:
commit_message = "Upload config"
else:
commit_message = f"Upload {self.__class__.__name__}"
# Push the files to the repo in a single commit
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path, card_kwargs=card_kwargs)
return api.upload_folder(
repo_id=repo_id,
repo_type="model",
folder_path=saved_path,
commit_message=commit_message,
revision=branch,
create_pr=create_pr,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
delete_patterns=delete_patterns,
)

View File

@@ -19,14 +19,13 @@ import os.path as osp
import platform
import random
from contextlib import contextmanager
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Generator
import hydra
import numpy as np
import torch
from omegaconf import DictConfig
def none_or_int(value):
@@ -41,9 +40,22 @@ def inside_slurm():
return "SLURM_JOB_ID" in os.environ
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
def auto_select_torch_device() -> torch.device:
"""Tries to select automatically a torch device."""
if torch.cuda.is_available():
logging.info("Cuda backend detected, using cuda.")
return torch.device("cuda")
elif torch.backends.mps.is_available():
logging.info("Metal backend detected, using cuda.")
return torch.device("mps")
else:
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
return torch.device("cpu")
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
match cfg_device:
match try_device:
case "cuda":
assert torch.cuda.is_available()
device = torch.device("cuda")
@@ -55,13 +67,33 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
if log:
logging.warning("Using CPU, this will be slow.")
case _:
device = torch.device(cfg_device)
device = torch.device(try_device)
if log:
logging.warning(f"Using custom {cfg_device} device.")
logging.warning(f"Using custom {try_device} device.")
return device
def is_torch_device_available(try_device: str) -> bool:
if try_device == "cuda":
return torch.cuda.is_available()
elif try_device == "mps":
return torch.backends.mps.is_available()
elif try_device == "cpu":
return True
else:
raise ValueError(f"Unknown device '{try_device}.")
def is_amp_available(device: str):
if device in ["cuda", "cpu"]:
return True
elif device == "mps":
return False
else:
raise ValueError(f"Unknown device '{device}.")
def get_global_random_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
random_state_dict = {
@@ -159,22 +191,6 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
)
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
"""Initialize a Hydra config given only the path to the relevant config file.
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
"""
# TODO(alexander-soare): Resolve configs without Hydra initialization.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Hydra needs a path relative to this file.
hydra.initialize(
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
version_base="1.2",
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg
def print_cuda_memory_usage():
"""Use this function to locate and debug memory leak."""
import gc
@@ -217,3 +233,17 @@ def log_say(text, play_sounds, blocking=False):
if play_sounds:
say(text, blocking)
def get_channel_first_image_shape(image_shape: tuple) -> tuple:
shape = copy(image_shape)
if shape[2] < shape[0] and shape[2] < shape[1]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif not (shape[0] < shape[1] and shape[0] < shape[2]):
raise ValueError(image_shape)
return shape
def has_method(cls: object, method_name: str):
return hasattr(cls, method_name) and callable(getattr(cls, method_name))