forked from tangger/lerobot
merge main
This commit is contained in:
@@ -58,7 +58,6 @@ available_tasks_per_env = {
|
||||
],
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
@@ -86,23 +85,6 @@ available_datasets_per_env = {
|
||||
"lerobot/xarm_push_medium_image",
|
||||
"lerobot/xarm_push_medium_replay_image",
|
||||
],
|
||||
"dora_aloha_real": [
|
||||
"lerobot/aloha_static_battery",
|
||||
"lerobot/aloha_static_candy",
|
||||
"lerobot/aloha_static_coffee",
|
||||
"lerobot/aloha_static_coffee_new",
|
||||
"lerobot/aloha_static_cups_open",
|
||||
"lerobot/aloha_static_fork_pick_up",
|
||||
"lerobot/aloha_static_pingpong_test",
|
||||
"lerobot/aloha_static_pro_pencil",
|
||||
"lerobot/aloha_static_screw_driver",
|
||||
"lerobot/aloha_static_tape",
|
||||
"lerobot/aloha_static_thread_velcro",
|
||||
"lerobot/aloha_static_towel",
|
||||
"lerobot/aloha_static_vinh_cup",
|
||||
"lerobot/aloha_static_vinh_cup_left",
|
||||
"lerobot/aloha_static_ziploc_slide",
|
||||
],
|
||||
}
|
||||
|
||||
available_real_world_datasets = [
|
||||
@@ -221,7 +203,6 @@ available_policies_per_env = {
|
||||
"xarm": ["tdmpc"],
|
||||
"koch_real": ["act_koch_real"],
|
||||
"aloha_real": ["act_aloha_real"],
|
||||
"dora_aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
|
||||
6
lerobot/common/constants.py
Normal file
6
lerobot/common/constants.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# keys
|
||||
OBS_ENV = "observation.environment_state"
|
||||
OBS_ROBOT = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
@@ -14,103 +14,102 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransforms
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
|
||||
}
|
||||
|
||||
|
||||
def resolve_delta_timestamps(cfg):
|
||||
"""Resolves delta_timestamps config key (in-place) by using `eval`.
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
) -> dict[str, list] | None:
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
||||
|
||||
Doesn't do anything if delta_timestamps is not specified or has already been resolve (as evidenced by
|
||||
the data type of its values).
|
||||
"""
|
||||
delta_timestamps = cfg.training.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
# TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
|
||||
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
|
||||
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""
|
||||
Args:
|
||||
cfg: A Hydra config as per the LeRobot config scheme.
|
||||
split: Select the data subset used to create an instance of LeRobotDataset.
|
||||
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
|
||||
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
|
||||
slicer in the hugging face datasets:
|
||||
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
|
||||
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
|
||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||
delta_timestamps against.
|
||||
|
||||
Returns:
|
||||
The LeRobotDataset.
|
||||
dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
|
||||
{
|
||||
"observation.state": [-0.04, -0.02, 0]
|
||||
"observation.action": [-0.02, 0, 0.02]
|
||||
}
|
||||
returns `None` if the the resulting dict is empty.
|
||||
"""
|
||||
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
|
||||
raise ValueError(
|
||||
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
|
||||
"strings to load multiple datasets."
|
||||
)
|
||||
delta_timestamps = {}
|
||||
for key in ds_meta.features:
|
||||
if key == "next.reward" and cfg.reward_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == "action" and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
||||
if cfg.env.name != "dora":
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
||||
else:
|
||||
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
|
||||
if len(delta_timestamps) == 0:
|
||||
delta_timestamps = None
|
||||
|
||||
for dataset_repo_id in dataset_repo_ids:
|
||||
if cfg.env.name not in dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
|
||||
f"environment ({cfg.env.name=})."
|
||||
)
|
||||
return delta_timestamps
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
image_transforms = None
|
||||
if cfg.training.image_transforms.enable:
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
# TODO (aliberts): add 'episodes' arg from config after removing hydra
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset | MultiLeRobotDataset
|
||||
"""
|
||||
image_transforms = (
|
||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
cfg.dataset.repo_id,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
local_files_only=cfg.dataset.local_files_only,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
cfg.dataset.repo_id,
|
||||
# TODO(aliberts): add proper support for multi dataset
|
||||
# delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(dataset.repo_id_to_index , indent=2)}"
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -304,6 +304,13 @@ class LeRobotDatasetMetadata:
|
||||
)
|
||||
else:
|
||||
# TODO(aliberts, rcadene): implement sanity check for features
|
||||
|
||||
# check if none of the features contains a "/" in their names,
|
||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||
for key in features:
|
||||
if "/" in key:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.stats, obj.episodes = {}, {}, []
|
||||
@@ -665,6 +672,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
@@ -833,7 +844,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def stop_image_writer(self) -> None:
|
||||
"""
|
||||
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
||||
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.stop()
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
@@ -65,6 +66,8 @@ class RandomSubsetApply(Transform):
|
||||
self.n_subset = n_subset
|
||||
self.random_order = random_order
|
||||
|
||||
self.selected_transforms = None
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
needs_unpacking = len(inputs) > 1
|
||||
|
||||
@@ -72,9 +75,9 @@ class RandomSubsetApply(Transform):
|
||||
if not self.random_order:
|
||||
selected_indices = selected_indices.sort().values
|
||||
|
||||
selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
self.selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
|
||||
for transform in selected_transforms:
|
||||
for transform in self.selected_transforms:
|
||||
outputs = transform(*inputs)
|
||||
inputs = outputs if needs_unpacking else (outputs,)
|
||||
|
||||
@@ -129,69 +132,118 @@ class SharpnessJitter(Transform):
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
def _generate_value(self, left: float, right: float) -> float:
|
||||
return torch.empty(1).uniform_(left, right).item()
|
||||
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
|
||||
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||
return {"sharpness_factor": sharpness_factor}
|
||||
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
||||
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
|
||||
sharpness_factor = params["sharpness_factor"]
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def get_image_transforms(
|
||||
brightness_weight: float = 1.0,
|
||||
brightness_min_max: tuple[float, float] | None = None,
|
||||
contrast_weight: float = 1.0,
|
||||
contrast_min_max: tuple[float, float] | None = None,
|
||||
saturation_weight: float = 1.0,
|
||||
saturation_min_max: tuple[float, float] | None = None,
|
||||
hue_weight: float = 1.0,
|
||||
hue_min_max: tuple[float, float] | None = None,
|
||||
sharpness_weight: float = 1.0,
|
||||
sharpness_min_max: tuple[float, float] | None = None,
|
||||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
):
|
||||
def check_value(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
if len(min_max) != 2:
|
||||
raise ValueError(
|
||||
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
|
||||
)
|
||||
if weight < 0.0:
|
||||
raise ValueError(
|
||||
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
|
||||
)
|
||||
@dataclass
|
||||
class ImageTransformConfig:
|
||||
"""
|
||||
For each transform, the following parameters are available:
|
||||
weight: This represents the multinomial probability (with no replacement)
|
||||
used for sampling the transform. If the sum of the weights is not 1,
|
||||
they will be normalized.
|
||||
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
|
||||
custom transform defined here.
|
||||
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
|
||||
(following uniform distribution) when it's applied.
|
||||
"""
|
||||
|
||||
check_value("brightness", brightness_weight, brightness_min_max)
|
||||
check_value("contrast", contrast_weight, contrast_min_max)
|
||||
check_value("saturation", saturation_weight, saturation_min_max)
|
||||
check_value("hue", hue_weight, hue_min_max)
|
||||
check_value("sharpness", sharpness_weight, sharpness_min_max)
|
||||
weight: float = 1.0
|
||||
type: str = "Identity"
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
if contrast_min_max is not None and contrast_weight > 0.0:
|
||||
weights.append(contrast_weight)
|
||||
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||
if saturation_min_max is not None and saturation_weight > 0.0:
|
||||
weights.append(saturation_weight)
|
||||
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||
if hue_min_max is not None and hue_weight > 0.0:
|
||||
weights.append(hue_weight)
|
||||
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||
if sharpness_min_max is not None and sharpness_weight > 0.0:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
|
||||
|
||||
n_subset = len(transforms)
|
||||
if max_num_transforms is not None:
|
||||
n_subset = min(n_subset, max_num_transforms)
|
||||
@dataclass
|
||||
class ImageTransformsConfig:
|
||||
"""
|
||||
These transforms are all using standard torchvision.transforms.v2
|
||||
You can find out how these transformations affect images here:
|
||||
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
|
||||
We use a custom RandomSubsetApply container to sample them.
|
||||
"""
|
||||
|
||||
if n_subset == 0:
|
||||
return v2.Identity()
|
||||
# Set this flag to `true` to enable transforms during training
|
||||
enable: bool = False
|
||||
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
|
||||
# It's an integer in the interval [1, number_of_available_transforms].
|
||||
max_num_transforms: int = 3
|
||||
# By default, transforms are applied in Torchvision's suggested order (shown below).
|
||||
# Set this to True to apply them in a random order.
|
||||
random_order: bool = False
|
||||
tfs: dict[str, ImageTransformConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.8, 1.2)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.8, 1.2)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 1.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (-0.05, 0.05)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 1.5)},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
else:
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
"""A class to compose image transforms based on configuration."""
|
||||
|
||||
def __init__(self, cfg: ImageTransformsConfig) -> None:
|
||||
super().__init__()
|
||||
self._cfg = cfg
|
||||
|
||||
self.weights = []
|
||||
self.transforms = {}
|
||||
for tf_name, tf_cfg in cfg.tfs.items():
|
||||
if tf_cfg.weight <= 0.0:
|
||||
continue
|
||||
|
||||
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
|
||||
self.weights.append(tf_cfg.weight)
|
||||
|
||||
n_subset = min(len(self.transforms), cfg.max_num_transforms)
|
||||
if n_subset == 0 or not cfg.enable:
|
||||
self.tf = v2.Identity()
|
||||
else:
|
||||
self.tf = RandomSubsetApply(
|
||||
transforms=list(self.transforms.values()),
|
||||
p=self.weights,
|
||||
n_subset=n_subset,
|
||||
random_order=cfg.random_order,
|
||||
)
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
return self.tf(*inputs)
|
||||
|
||||
@@ -35,6 +35,7 @@ from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
|
||||
@@ -98,6 +99,18 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
return outdict
|
||||
|
||||
|
||||
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
||||
split_keys = flattened_key.split(sep)
|
||||
getter = obj[split_keys[0]]
|
||||
if len(split_keys) == 1:
|
||||
return getter
|
||||
|
||||
for key in split_keys[1:]:
|
||||
getter = getter[key]
|
||||
|
||||
return getter
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(serialized_dict)
|
||||
@@ -289,6 +302,37 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
|
||||
|
||||
|
||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
||||
policy_features = {}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
if ft["dtype"] in ["image", "video"]:
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == "observation.environment_state":
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith("observation"):
|
||||
type = FeatureType.STATE
|
||||
elif key == "action":
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
|
||||
policy_features[key] = PolicyFeature(
|
||||
type=type,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
return policy_features
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
@@ -436,7 +480,7 @@ def check_delta_timestamps(
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
|
||||
return delta_indices
|
||||
|
||||
|
||||
@@ -26,13 +26,13 @@ from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
|
||||
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
|
||||
ALOHA_MOBILE_INFO = {
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"license": "mit",
|
||||
"url": "https://mobile-aloha.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.02117",
|
||||
@@ -45,7 +45,7 @@ ALOHA_MOBILE_INFO = {
|
||||
}""").lstrip(),
|
||||
}
|
||||
ALOHA_STATIC_INFO = {
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"license": "mit",
|
||||
"url": "https://tonyzhaozh.github.io/aloha/",
|
||||
"paper": "https://arxiv.org/abs/2304.13705",
|
||||
|
||||
@@ -141,7 +141,8 @@ from lerobot.common.datasets.video_utils import (
|
||||
get_image_pixel_channels,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_config
|
||||
|
||||
V16 = "v1.6"
|
||||
V20 = "v2.0"
|
||||
@@ -152,19 +153,18 @@ V1_INFO_PATH = "meta_data/info.json"
|
||||
V1_STATS_PATH = "meta_data/stats.safetensors"
|
||||
|
||||
|
||||
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
|
||||
robot_cfg = init_hydra_config(config_path, config_overrides)
|
||||
if robot_cfg["robot_type"] in ["aloha", "koch"]:
|
||||
def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
|
||||
if robot_cfg.type in ["aloha", "koch"]:
|
||||
state_names = [
|
||||
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["follower_arms"]
|
||||
for motor in robot_cfg["follower_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
|
||||
for arm in robot_cfg.follower_arms
|
||||
for motor in robot_cfg.follower_arms[arm].motors
|
||||
]
|
||||
action_names = [
|
||||
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["leader_arms"]
|
||||
for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
|
||||
for arm in robot_cfg.leader_arms
|
||||
for motor in robot_cfg.leader_arms[arm].motors
|
||||
]
|
||||
# elif robot_cfg["robot_type"] == "stretch3": TODO
|
||||
else:
|
||||
@@ -173,7 +173,7 @@ def parse_robot_config(config_path: Path, config_overrides: list[str] | None = N
|
||||
)
|
||||
|
||||
return {
|
||||
"robot_type": robot_cfg["robot_type"],
|
||||
"robot_type": robot_cfg.type,
|
||||
"names": {
|
||||
"observation.state": state_names,
|
||||
"observation.effort": state_names,
|
||||
@@ -203,7 +203,10 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
torch.testing.assert_close(stats_json[key], stats[key])
|
||||
|
||||
|
||||
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
|
||||
def get_features_from_hf_dataset(
|
||||
dataset: Dataset, robot_config: RobotConfig | None = None
|
||||
) -> dict[str, list]:
|
||||
robot_config = parse_robot_config(robot_config)
|
||||
features = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Value):
|
||||
@@ -224,11 +227,11 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
|
||||
image = dataset[0][key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
shape = (image.height, image.width, channels)
|
||||
names = ["height", "width", "channel"]
|
||||
names = ["height", "width", "channels"]
|
||||
elif ft._type == "VideoFrame":
|
||||
dtype = "video"
|
||||
shape = None # Add shape later
|
||||
names = ["height", "width", "channel"]
|
||||
names = ["height", "width", "channels"]
|
||||
|
||||
features[key] = {
|
||||
"dtype": dtype,
|
||||
@@ -436,7 +439,7 @@ def convert_dataset(
|
||||
single_task: str | None = None,
|
||||
tasks_path: Path | None = None,
|
||||
tasks_col: Path | None = None,
|
||||
robot_config: dict | None = None,
|
||||
robot_config: RobotConfig | None = None,
|
||||
test_branch: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
@@ -532,7 +535,7 @@ def convert_dataset(
|
||||
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
|
||||
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config["robot_type"]
|
||||
robot_type = robot_config.type
|
||||
repo_tags = [robot_type]
|
||||
else:
|
||||
robot_type = "unknown"
|
||||
@@ -621,16 +624,10 @@ def main():
|
||||
help="The path to a .json file containing one language instruction for each episode_index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-config",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Path to the robot's config yaml the dataset during conversion.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-overrides",
|
||||
"--robot",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)",
|
||||
default=None,
|
||||
help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
@@ -655,8 +652,10 @@ def main():
|
||||
if not args.local_dir:
|
||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||
|
||||
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
|
||||
del args.robot_config, args.robot_overrides
|
||||
if args.robot is not None:
|
||||
robot_config = make_robot_config(args.robot)
|
||||
|
||||
del args.robot
|
||||
|
||||
convert_dataset(**vars(args), robot_config=robot_config)
|
||||
|
||||
|
||||
1
lerobot/common/envs/__init__.py
Normal file
1
lerobot/common/envs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
||||
142
lerobot/common/envs/configs.py
Normal file
142
lerobot/common/envs/configs.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
task: str | None = None
|
||||
fps: int = 30
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
def gym_kwargs(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaEnv(EnvConfig):
|
||||
task: str = "AlohaInsertion-v0"
|
||||
fps: int = 50
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("pusht")
|
||||
@dataclass
|
||||
class PushtEnv(EnvConfig):
|
||||
task: str = "PushT-v0"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"environment_state": OBS_ENV,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
||||
elif self.obs_type == "environment_state_agent_pos":
|
||||
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("xarm")
|
||||
@dataclass
|
||||
class XarmEnv(EnvConfig):
|
||||
task: str = "XarmLift-v0"
|
||||
fps: int = 15
|
||||
episode_length: int = 200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
@@ -16,43 +16,54 @@
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the evaluation config.
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
if env_type == "aloha":
|
||||
return AlohaEnv(**kwargs)
|
||||
elif env_type == "pusht":
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||
|
||||
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the config.
|
||||
|
||||
Args:
|
||||
cfg (EnvConfig): the config of the environment to instantiate.
|
||||
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
|
||||
use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
||||
False.
|
||||
|
||||
Raises:
|
||||
ValueError: if n_envs < 1
|
||||
ModuleNotFoundError: If the requested env package is not intalled
|
||||
|
||||
Returns:
|
||||
gym.vector.VectorEnv: The parallelized gym.env instance.
|
||||
"""
|
||||
if n_envs is not None and n_envs < 1:
|
||||
if n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
package_name = f"gym_{cfg.type}"
|
||||
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
|
||||
)
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||
gym_kwgs = dict(cfg.env.get("gym", {}))
|
||||
|
||||
if cfg.env.get("episode_length"):
|
||||
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
]
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
@@ -18,8 +18,13 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.utils.utils import get_channel_first_image_shape
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
@@ -35,6 +40,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# sanity check that images are channel last
|
||||
@@ -60,3 +66,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return return_observations
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||
policy_features = {}
|
||||
for key, ft in env_cfg.features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
if len(ft.shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
|
||||
|
||||
shape = get_channel_first_image_shape(ft.shape)
|
||||
feature = PolicyFeature(type=ft.type, shape=shape)
|
||||
else:
|
||||
feature = ft
|
||||
|
||||
policy_key = env_cfg.features_map[key]
|
||||
policy_features[policy_key] = feature
|
||||
|
||||
return policy_features
|
||||
|
||||
@@ -21,32 +21,39 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import asdict
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_global_random_state
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode
|
||||
|
||||
PRETRAINED_MODEL = "pretrained_model"
|
||||
TRAINING_STATE = "training_state.pth"
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
||||
|
||||
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.name}",
|
||||
f"dataset:{cfg.dataset_repo_id}",
|
||||
f"env:{cfg.env.name}",
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"dataset:{cfg.dataset.repo_id}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
@@ -68,7 +75,6 @@ class Logger:
|
||||
The logger creates the following directory structure:
|
||||
|
||||
provided_log_dir
|
||||
├── .hydra # hydra's configuration cache
|
||||
├── checkpoints
|
||||
│ ├── specific_checkpoint_name
|
||||
│ │ ├── pretrained_model # Hugging Face pretrained model directory
|
||||
@@ -80,28 +86,21 @@ class Logger:
|
||||
│ └── last # a softlink to the last logged checkpoint
|
||||
"""
|
||||
|
||||
pretrained_model_dir_name = "pretrained_model"
|
||||
training_state_file_name = "training_state.pth"
|
||||
pretrained_model_dir_name = PRETRAINED_MODEL
|
||||
training_state_file_name = TRAINING_STATE
|
||||
|
||||
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
log_dir: The directory to save all logs and training outputs to.
|
||||
job_name: The WandB job name.
|
||||
"""
|
||||
def __init__(self, cfg: TrainPipelineConfig):
|
||||
self._cfg = cfg
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir = cfg.output_dir
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
|
||||
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
|
||||
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
|
||||
self.job_name = cfg.job_name
|
||||
self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir)
|
||||
self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir)
|
||||
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir)
|
||||
|
||||
# Set up WandB.
|
||||
self._group = cfg_to_group(cfg)
|
||||
project = cfg.get("wandb", {}).get("project")
|
||||
entity = cfg.get("wandb", {}).get("entity")
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project
|
||||
run_offline = not cfg.wandb.enable or not cfg.wandb.project
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
@@ -115,13 +114,13 @@ class Logger:
|
||||
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=project,
|
||||
entity=entity,
|
||||
name=wandb_job_name,
|
||||
notes=cfg.get("wandb", {}).get("notes"),
|
||||
project=cfg.wandb.project,
|
||||
entity=cfg.wandb.entity,
|
||||
name=self.job_name,
|
||||
notes=cfg.wandb.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
dir=log_dir,
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
dir=self.log_dir,
|
||||
config=asdict(self._cfg),
|
||||
# TODO(rcadene): try set to True
|
||||
save_code=False,
|
||||
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
|
||||
@@ -150,17 +149,19 @@ class Logger:
|
||||
"""
|
||||
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
|
||||
|
||||
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
|
||||
def save_model(self, save_dir: Path, policy: PreTrainedPolicy, wandb_artifact_name: str | None = None):
|
||||
"""Save the weights of the Policy model using PyTorchModelHubMixin.
|
||||
|
||||
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
|
||||
|
||||
Optionally also upload the model to WandB.
|
||||
"""
|
||||
|
||||
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
||||
register_features_types()
|
||||
policy.save_pretrained(save_dir)
|
||||
# Also save the full Hydra config for the env configuration.
|
||||
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
||||
# Also save the full config for the env configuration.
|
||||
self._cfg.save_pretrained(save_dir)
|
||||
if self._wandb and not self._cfg.wandb.disable_artifact:
|
||||
# note wandb artifact does not accept ":" or "/" in its name
|
||||
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
|
||||
@@ -173,18 +174,18 @@ class Logger:
|
||||
self,
|
||||
save_dir: Path,
|
||||
train_step: int,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
):
|
||||
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
|
||||
|
||||
All of these are saved as "training_state.pth" under the checkpoint directory.
|
||||
"""
|
||||
training_state = {
|
||||
"step": train_step,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
**get_global_random_state(),
|
||||
}
|
||||
training_state = {}
|
||||
training_state["step"] = train_step
|
||||
training_state.update(get_global_random_state())
|
||||
if optimizer is not None:
|
||||
training_state["optimizer"] = optimizer.state_dict()
|
||||
if scheduler is not None:
|
||||
training_state["scheduler"] = scheduler.state_dict()
|
||||
torch.save(training_state, save_dir / self.training_state_file_name)
|
||||
@@ -192,10 +193,10 @@ class Logger:
|
||||
def save_checkpoint(
|
||||
self,
|
||||
train_step: int,
|
||||
policy: Policy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None,
|
||||
identifier: str,
|
||||
policy: PreTrainedPolicy,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
):
|
||||
"""Checkpoint the model weights and the training state."""
|
||||
checkpoint_dir = self.checkpoints_dir / str(identifier)
|
||||
@@ -208,26 +209,11 @@ class Logger:
|
||||
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
|
||||
)
|
||||
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
|
||||
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
|
||||
|
||||
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
|
||||
"""
|
||||
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
|
||||
random state, and return the global training step.
|
||||
"""
|
||||
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
|
||||
optimizer.load_state_dict(training_state["optimizer"])
|
||||
if scheduler is not None:
|
||||
scheduler.load_state_dict(training_state["scheduler"])
|
||||
elif "scheduler" in training_state:
|
||||
raise ValueError(
|
||||
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
|
||||
)
|
||||
# Small hack to get the expected keys: use `get_global_random_state`.
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"]
|
||||
relative_target = checkpoint_dir.relative_to(self.last_checkpoint_dir.parent)
|
||||
self.last_checkpoint_dir.symlink_to(relative_target)
|
||||
|
||||
def log_dict(self, d, step, mode="train"):
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if self._wandb is not None:
|
||||
@@ -242,5 +228,13 @@ class Logger:
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
assert self._wandb is not None
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.env.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
||||
|
||||
def register_features_types():
|
||||
draccus.decode.register(FeatureType, lambda x: FeatureType[x])
|
||||
draccus.encode.register(FeatureType, lambda x: x.name)
|
||||
|
||||
draccus.decode.register(NormalizationMode, lambda x: NormalizationMode[x])
|
||||
draccus.encode.register(NormalizationMode, lambda x: x.name)
|
||||
|
||||
1
lerobot/common/optim/__init__.py
Normal file
1
lerobot/common/optim/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .optimizers import OptimizerConfig as OptimizerConfig
|
||||
61
lerobot/common/optim/factory.py
Normal file
61
lerobot/common/optim/factory.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.logger import TRAINING_STATE
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(
|
||||
cfg: TrainPipelineConfig, policy: PreTrainedPolicy
|
||||
) -> tuple[Optimizer, LRScheduler | None]:
|
||||
"""Generates the optimizer and scheduler based on configs.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
|
||||
policy (PreTrainedPolicy): The policy config from which parameters and presets must be taken from.
|
||||
|
||||
Returns:
|
||||
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
||||
"""
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.offline.steps) if cfg.scheduler is not None else None
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def load_training_state(checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
|
||||
"""
|
||||
Given the checkpoint directory, load the optimizer state, scheduler state, and random state, and
|
||||
return the global training step.
|
||||
"""
|
||||
# TODO(aliberts): use safetensors instead as weights_only=False is unsafe
|
||||
training_state = torch.load(checkpoint_dir / TRAINING_STATE, weights_only=False)
|
||||
optimizer.load_state_dict(training_state["optimizer"])
|
||||
if scheduler is not None:
|
||||
scheduler.load_state_dict(training_state["scheduler"])
|
||||
elif "scheduler" in training_state:
|
||||
raise ValueError("The checkpoint contains a scheduler state_dict, but no LRScheduler was provided.")
|
||||
# Small HACK to get the expected keys: use `get_global_random_state`.
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"], optimizer, scheduler
|
||||
70
lerobot/common/optim/optimizers.py
Normal file
70
lerobot/common/optim/optimizers.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
lr: float
|
||||
weight_decay: float
|
||||
grad_clip_norm: float
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@classmethod
|
||||
def default_choice_name(cls) -> str | None:
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self) -> torch.optim.Optimizer:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("adam")
|
||||
@dataclass
|
||||
class AdamConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
betas: tuple[float, float] = (0.9, 0.999)
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.Adam(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("adamw")
|
||||
@dataclass
|
||||
class AdamWConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
betas: tuple[float, float] = (0.9, 0.999)
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.AdamW(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("sgd")
|
||||
@dataclass
|
||||
class SGDConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
momentum: float = 0.0
|
||||
dampening: float = 0.0
|
||||
nesterov: bool = False
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
91
lerobot/common/optim/schedulers.py
Normal file
91
lerobot/common/optim/schedulers.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import abc
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
|
||||
@dataclass
|
||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
num_warmup_steps: int
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("diffuser")
|
||||
@dataclass
|
||||
class DiffuserSchedulerConfig(LRSchedulerConfig):
|
||||
name: str = "cosine"
|
||||
num_warmup_steps: int | None = None
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
|
||||
return get_scheduler(**kwargs)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("vqbet")
|
||||
@dataclass
|
||||
class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
num_warmup_steps: int
|
||||
num_vqvae_training_steps: int
|
||||
num_cycles: float = 0.5
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step):
|
||||
if current_step < self.num_vqvae_training_steps:
|
||||
return float(1)
|
||||
else:
|
||||
adjusted_step = current_step - self.num_vqvae_training_steps
|
||||
if adjusted_step < self.num_warmup_steps:
|
||||
return float(adjusted_step) / float(max(1, self.num_warmup_steps))
|
||||
progress = float(adjusted_step - self.num_warmup_steps) / float(
|
||||
max(1, num_training_steps - self.num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by Physical Intelligence to train Pi0"""
|
||||
|
||||
num_warmup_steps: int
|
||||
num_decay_steps: int
|
||||
peak_lr: float
|
||||
decay_lr: float
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
del num_training_steps
|
||||
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
return 1 / (self.num_warmup_steps + 1)
|
||||
frac = 1 - current_step / self.num_warmup_steps
|
||||
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
|
||||
if current_step < self.num_warmup_steps:
|
||||
return linear_warmup_schedule(current_step)
|
||||
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
5
lerobot/common/policies/__init__.py
Normal file
5
lerobot/common/policies/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
@@ -15,9 +15,14 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("act")
|
||||
@dataclass
|
||||
class ACTConfig:
|
||||
class ACTConfig(PreTrainedConfig):
|
||||
"""Configuration class for the Action Chunking Transformers policy.
|
||||
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
@@ -90,28 +95,11 @@ class ACTConfig:
|
||||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": [3, 480, 640],
|
||||
"observation.state": [14],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [14],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "mean_std",
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -144,7 +132,14 @@ class ACTConfig:
|
||||
dropout: float = 0.1
|
||||
kl_weight: float = 10.0
|
||||
|
||||
# Training preset
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
optimizer_lr_backbone: float = 1e-5
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
@@ -164,8 +159,28 @@ class ACTConfig:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if (
|
||||
not any(k.startswith("observation.image") for k in self.input_shapes)
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -29,32 +29,27 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class ACTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "act"],
|
||||
):
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
"""
|
||||
|
||||
config_class = ACTConfig
|
||||
name = "act"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig | None = None,
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -64,30 +59,46 @@ class ACTPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = ACTConfig()
|
||||
self.config: ACTConfig = config
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
# TODO(aliberts, rcadene): As of now, lr_backbone == lr
|
||||
# Should we remove this and just `return self.parameters()`?
|
||||
return [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": self.config.optimizer_lr_backbone,
|
||||
},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
@@ -106,9 +117,11 @@ class ACTPolicy(
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -134,9 +147,11 @@ class ACTPolicy(
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
@@ -288,31 +303,30 @@ class ACT(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
self.use_robot_state = "observation.state" in config.input_shapes
|
||||
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
)
|
||||
# Projection layer for action (joint-space target) to hidden dimension.
|
||||
self.vae_encoder_action_input_proj = nn.Linear(
|
||||
config.output_shapes["action"][0], config.dim_model
|
||||
self.config.action_feature.shape[0],
|
||||
config.dim_model,
|
||||
)
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
@@ -320,7 +334,7 @@ class ACT(nn.Module):
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
@@ -337,27 +351,27 @@ class ACT(nn.Module):
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
)
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
self.encoder_env_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.dim_model
|
||||
self.config.env_state_feature.shape[0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
n_1d_tokens = 1 # for the latent
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
n_1d_tokens += 1
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
@@ -365,7 +379,7 @@ class ACT(nn.Module):
|
||||
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
|
||||
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
@@ -380,13 +394,13 @@ class ACT(nn.Module):
|
||||
|
||||
`batch` should have the following structure:
|
||||
{
|
||||
"observation.state" (optional): (B, state_dim) batch of robot states.
|
||||
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
|
||||
|
||||
"observation.images": (B, n_cameras, C, H, W) batch of images.
|
||||
[image_features]: (B, n_cameras, C, H, W) batch of images.
|
||||
AND/OR
|
||||
"observation.environment_state": (B, env_dim) batch of environment states.
|
||||
[env_state_feature]: (B, env_dim) batch of environment states.
|
||||
|
||||
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
}
|
||||
|
||||
Returns:
|
||||
@@ -411,12 +425,12 @@ class ACT(nn.Module):
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
@@ -430,7 +444,7 @@ class ACT(nn.Module):
|
||||
# sequence depending whether we use the input states or not (cls and robot state)
|
||||
# False means not a padding token.
|
||||
cls_joint_is_pad = torch.full(
|
||||
(batch_size, 2 if self.use_robot_state else 1),
|
||||
(batch_size, 2 if self.config.robot_state_feature else 1),
|
||||
False,
|
||||
device=batch["observation.state"].device,
|
||||
)
|
||||
@@ -463,16 +477,16 @@ class ACT(nn.Module):
|
||||
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
# Robot state token.
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
# Environment state token.
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
|
||||
# Camera observation features and positional embeddings.
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
|
||||
@@ -16,9 +16,15 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("diffusion")
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
class DiffusionConfig(PreTrainedConfig):
|
||||
"""Configuration class for DiffusionPolicy.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
@@ -102,26 +108,17 @@ class DiffusionConfig:
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -154,39 +151,23 @@ class DiffusionConfig:
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-6
|
||||
scheduler_name: str = "cosine"
|
||||
scheduler_warmup_steps: int = 500
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
|
||||
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if len(image_keys) > 0:
|
||||
if self.crop_shape is not None:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
raise ValueError(
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
)
|
||||
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
raise ValueError(
|
||||
@@ -207,3 +188,50 @@ class DiffusionConfig:
|
||||
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
||||
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
return DiffuserSchedulerConfig(
|
||||
name=self.scheduler_name,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -31,35 +31,32 @@ import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
|
||||
class DiffusionPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "diffusion-policy"],
|
||||
):
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
"""
|
||||
|
||||
config_class = DiffusionConfig
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig | None = None,
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -69,18 +66,16 @@ class DiffusionPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = DiffusionConfig()
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
@@ -88,20 +83,20 @@ class DiffusionPolicy(
|
||||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.diffusion.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad
|
||||
@@ -127,9 +122,11 @@ class DiffusionPolicy(
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -149,9 +146,11 @@ class DiffusionPolicy(
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
@@ -176,12 +175,9 @@ class DiffusionModel(nn.Module):
|
||||
self.config = config
|
||||
|
||||
# Build observation encoders (depending on which observations are provided).
|
||||
global_cond_dim = config.input_shapes["observation.state"][0]
|
||||
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self._use_images = False
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
self._use_images = True
|
||||
global_cond_dim = self.config.robot_state_feature.shape[0]
|
||||
if self.config.image_features:
|
||||
num_images = len(self.config.image_features)
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
|
||||
self.rgb_encoder = nn.ModuleList(encoders)
|
||||
@@ -189,9 +185,8 @@ class DiffusionModel(nn.Module):
|
||||
else:
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
if self.config.env_state_feature:
|
||||
global_cond_dim += self.config.env_state_feature.shape[0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
@@ -220,7 +215,7 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
|
||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
@@ -242,10 +237,10 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
global_cond_feats = [batch["observation.state"]]
|
||||
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
|
||||
global_cond_feats = [batch[OBS_ROBOT]]
|
||||
# Extract image features.
|
||||
if self._use_images:
|
||||
if self.config.image_features:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
@@ -272,8 +267,8 @@ class DiffusionModel(nn.Module):
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self._use_env_state:
|
||||
global_cond_feats.append(batch["observation.environment_state"])
|
||||
if self.config.env_state_feature:
|
||||
global_cond_feats.append(batch[OBS_ENV])
|
||||
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
@@ -443,7 +438,7 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
|
||||
class DiffusionRgbEncoder(nn.Module):
|
||||
"""Encoder an RGB image into a 1D feature vector.
|
||||
"""Encodes an RGB image into a 1D feature vector.
|
||||
|
||||
Includes the ability to normalize and crop the image first.
|
||||
"""
|
||||
@@ -482,19 +477,16 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
@@ -611,7 +603,7 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
|
||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||
# just reverse these.
|
||||
in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
|
||||
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
|
||||
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
|
||||
)
|
||||
|
||||
@@ -666,7 +658,7 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
|
||||
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||
|
||||
@@ -13,99 +13,141 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
|
||||
import logging
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
|
||||
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
||||
if not set(hydra_cfg.policy).issuperset(expected_kwargs):
|
||||
logging.warning(
|
||||
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
)
|
||||
|
||||
# OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
|
||||
# issues with mutable defaults. This filter changes all lists to tuples.
|
||||
def list_to_tuple(item):
|
||||
return tuple(item) if isinstance(item, list) else item
|
||||
|
||||
policy_cfg = policy_cfg_class(
|
||||
**{
|
||||
k: list_to_tuple(v)
|
||||
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
|
||||
if k in expected_kwargs
|
||||
}
|
||||
)
|
||||
return policy_cfg
|
||||
|
||||
|
||||
def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy, TDMPCConfig
|
||||
return TDMPCPolicy
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
return DiffusionPolicy, DiffusionConfig
|
||||
return DiffusionPolicy
|
||||
elif name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
return ACTPolicy, ACTConfig
|
||||
return ACTPolicy
|
||||
elif name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy, VQBeTConfig
|
||||
return VQBeTPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
if policy_type == "tdmpc":
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "act":
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
def make_policy(
|
||||
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
|
||||
) -> Policy:
|
||||
cfg: PreTrainedConfig,
|
||||
device: str | torch.device,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
"""Make an instance of a policy class.
|
||||
|
||||
This function exists because (for now) we need to parse features from either a dataset or an environment
|
||||
in order to properly dimension and instantiate a policy for that dataset or environment.
|
||||
|
||||
Args:
|
||||
hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
|
||||
provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
|
||||
pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
|
||||
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
|
||||
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
|
||||
dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
|
||||
be provided when initializing a new policy, and must not be provided when loading a pretrained
|
||||
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||
be loaded with the weights from that path.
|
||||
device (str): the device to load the policy onto.
|
||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||
provided if ds_meta is not. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
|
||||
|
||||
Returns:
|
||||
PreTrainedPolicy: _description_
|
||||
"""
|
||||
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
raise ValueError(
|
||||
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||
if bool(ds_meta) == bool(env_cfg):
|
||||
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
|
||||
|
||||
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
|
||||
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
|
||||
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
|
||||
# you want this op to be added in priority during the prototype phase of this feature, please comment on
|
||||
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
||||
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
||||
# slower than running natively on MPS.
|
||||
if cfg.type == "vqbet" and str(device) == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
policy_cls = get_policy_class(cfg.type)
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
if pretrained_policy_name_or_path is None:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
kwargs = {}
|
||||
if ds_meta is not None:
|
||||
features = dataset_to_policy_features(ds_meta.features)
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
else:
|
||||
if not cfg.pretrained_path:
|
||||
logging.warning(
|
||||
"You are instantiating a policy from scratch and its features are parsed from an environment "
|
||||
"rather than a dataset. Normalization modules inside the policy will have infinite values "
|
||||
"by default without stats from a dataset."
|
||||
)
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
|
||||
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
|
||||
# huggingface_hub should make it possible to avoid the hack:
|
||||
# https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(**kwargs)
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
policy.to(device)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
return policy
|
||||
|
||||
@@ -16,10 +16,12 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def create_stats_buffers(
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
@@ -34,12 +36,16 @@ def create_stats_buffers(
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, mode in modes.items():
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
shape = tuple(shapes[key])
|
||||
assert isinstance(norm_mode, NormalizationMode)
|
||||
|
||||
if "image" in key:
|
||||
shape = tuple(ft.shape)
|
||||
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
@@ -52,7 +58,7 @@ def create_stats_buffers(
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if mode == "mean_std":
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
@@ -61,7 +67,7 @@ def create_stats_buffers(
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif mode == "min_max":
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
@@ -71,15 +77,15 @@ def create_stats_buffers(
|
||||
}
|
||||
)
|
||||
|
||||
if stats is not None:
|
||||
if stats:
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if mode == "mean_std":
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone()
|
||||
buffer["std"].data = stats[key]["std"].clone()
|
||||
elif mode == "min_max":
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone()
|
||||
buffer["max"].data = stats[key]["max"].clone()
|
||||
|
||||
@@ -99,8 +105,8 @@ class Normalize(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -122,10 +128,10 @@ class Normalize(nn.Module):
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@@ -133,16 +139,23 @@ class Normalize(nn.Module):
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
@@ -152,7 +165,7 @@ class Normalize(nn.Module):
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
@@ -164,8 +177,8 @@ class Unnormalize(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -187,11 +200,11 @@ class Unnormalize(nn.Module):
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@@ -199,16 +212,23 @@ class Unnormalize(nn.Module):
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
@@ -216,5 +236,5 @@ class Unnormalize(nn.Module):
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
134
lerobot/common/policies/pi0/configuration_pi0.py
Normal file
134
lerobot/common/policies/pi0/configuration_pi0.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
class PI0Config(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
||||
|
||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Projector
|
||||
proj_width: int = 1024
|
||||
|
||||
# Decoding
|
||||
num_steps: int = 10
|
||||
|
||||
# Attention utils
|
||||
use_cache: bool = True
|
||||
attention_implementation: str = "eager" # or fa2, flex
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = False
|
||||
train_state_proj: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 2.5e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
# TODO: Add EMA
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
|
||||
if self.use_delta_joint_actions_aloha:
|
||||
raise NotImplementedError(
|
||||
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: implement value error
|
||||
# if not self.image_features and not self.env_state_feature:
|
||||
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
68
lerobot/common/policies/pi0/conversion_scripts/benchmark.py
Normal file
68
lerobot/common/policies/pi0/conversion_scripts/benchmark.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda"
|
||||
dataset_repo_id = "danaaubakirova/koch_test"
|
||||
# model_name = "pi0_base"
|
||||
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_torch_dir = "lerobot/pi0"
|
||||
|
||||
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, device, ds_meta=dataset.meta)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
warmup_iters = 10
|
||||
benchmark_iters = 30
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iters):
|
||||
torch.cuda.synchronize()
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for _ in range(benchmark_iters):
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
end_event.record()
|
||||
|
||||
# Synchronize and measure time
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time_ms = start_event.elapsed_time(end_event)
|
||||
|
||||
avg_time_per_iter = elapsed_time_ms / benchmark_iters
|
||||
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.inference_mode():
|
||||
main()
|
||||
@@ -0,0 +1,117 @@
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
if tensor.dtype == torch.bool:
|
||||
tensor = tensor.float()
|
||||
print(f"Shape: {tensor.shape}")
|
||||
print(f"Mean: {tensor.mean().item()}")
|
||||
print(f"Std: {tensor.std().item()}")
|
||||
print(f"Min: {tensor.min().item()}")
|
||||
print(f"Max: {tensor.max().item()}")
|
||||
|
||||
|
||||
def main():
|
||||
num_motors = 14
|
||||
device = "cuda"
|
||||
# model_name = "pi0_aloha_towel"
|
||||
model_name = "pi0_aloha_sim"
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
dataset_repo_id = "lerobot/aloha_static_towel"
|
||||
else:
|
||||
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
|
||||
|
||||
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
|
||||
save_dir = Path(f"../openpi/data/{model_name}/save")
|
||||
|
||||
with open(save_dir / "example.pkl", "rb") as f:
|
||||
example = pickle.load(f)
|
||||
with open(save_dir / "outputs.pkl", "rb") as f:
|
||||
outputs = pickle.load(f)
|
||||
with open(save_dir / "noise.pkl", "rb") as f:
|
||||
noise = pickle.load(f)
|
||||
|
||||
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
|
||||
norm_stats = json.load(f)
|
||||
|
||||
# Override stats
|
||||
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
||||
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Create LeRobot batch from Jax
|
||||
batch = {}
|
||||
for cam_key, uint_chw_array in example["images"].items():
|
||||
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch["observation.state"] = torch.from_numpy(example["state"])
|
||||
batch["action"] = torch.from_numpy(outputs["actions"])
|
||||
batch["task"] = example["prompt"]
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
del batch["observation.images.cam_low"]
|
||||
elif model_name == "pi0_aloha_sim":
|
||||
batch["observation.images.top"] = batch["observation.images.cam_high"]
|
||||
del batch["observation.images.cam_high"]
|
||||
|
||||
# Batchify
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].unsqueeze(0)
|
||||
elif isinstance(batch[key], str):
|
||||
batch[key] = [batch[key]]
|
||||
else:
|
||||
raise ValueError(f"{key}, {batch[key]}")
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
|
||||
|
||||
from lerobot.common import policies # noqa
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, device, dataset_meta)
|
||||
|
||||
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
||||
# loss_dict["loss"].backward()
|
||||
# print("losses")
|
||||
# display(loss_dict["losses_after_forward"])
|
||||
# print("pi_losses")
|
||||
# display(pi_losses)
|
||||
|
||||
actions = []
|
||||
for _ in range(50):
|
||||
action = policy.select_action(batch, noise=noise)
|
||||
actions.append(action)
|
||||
|
||||
actions = torch.stack(actions, dim=1)
|
||||
pi_actions = batch["action"]
|
||||
print("actions")
|
||||
display(actions)
|
||||
print()
|
||||
print("pi_actions")
|
||||
display(pi_actions)
|
||||
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
|
||||
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
|
||||
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,70 @@
|
||||
from transformers import GemmaConfig, PaliGemmaConfig
|
||||
|
||||
|
||||
def get_paligemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
|
||||
|
||||
image_size = 224 # image_sizes[variant]
|
||||
patch_size = 14
|
||||
num_image_tokens = (image_size**2) // (patch_size**2)
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 2048,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 16384,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
vision_config = {
|
||||
"torch_dtype": precision,
|
||||
"image_size": image_size,
|
||||
"patch_size": patch_size,
|
||||
"num_image_tokens": num_image_tokens,
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"num_hidden_layers": 27,
|
||||
"num_attention_heads": 16,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"vision_use_head": False,
|
||||
}
|
||||
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
|
||||
return final_config
|
||||
|
||||
|
||||
def get_gemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 1024,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 4096,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
final_config = GemmaConfig()
|
||||
final_config.update(text_config)
|
||||
return final_config
|
||||
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Convert pi0 parameters from Jax to Pytorch
|
||||
|
||||
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
|
||||
and install the required librairies.
|
||||
|
||||
```bash
|
||||
cd ~/code/openpi
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Example downloading parameters:
|
||||
```bash
|
||||
python
|
||||
>>> import openpi.shared.download as download
|
||||
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
|
||||
>>> download.maybe_download(path)
|
||||
```
|
||||
|
||||
Converting pi0_base:
|
||||
```python
|
||||
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
|
||||
```
|
||||
|
||||
```python
|
||||
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
import torch
|
||||
from jax.sharding import SingleDeviceSharding
|
||||
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
|
||||
get_gemma_config,
|
||||
get_paligemma_config,
|
||||
)
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
|
||||
|
||||
def slice_paligemma_state_dict(state_dict, config):
|
||||
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
||||
|
||||
# fmt: off
|
||||
# patch embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
|
||||
3, 2, 0, 1
|
||||
)
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
|
||||
# positional embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
|
||||
-1, config.vision_config.hidden_size
|
||||
)
|
||||
|
||||
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
|
||||
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
|
||||
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
|
||||
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
||||
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
||||
|
||||
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
||||
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
||||
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
||||
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
||||
|
||||
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
|
||||
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
|
||||
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
|
||||
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
|
||||
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
|
||||
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
|
||||
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
|
||||
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
|
||||
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
|
||||
|
||||
# multimodal projector
|
||||
|
||||
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
|
||||
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
|
||||
|
||||
# text decoder (gemma)
|
||||
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
|
||||
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
||||
|
||||
for i in range(config.text_config.num_hidden_layers):
|
||||
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
|
||||
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
|
||||
|
||||
# fmt: on
|
||||
expert_dict = {}
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key not in [
|
||||
f"llm/final_norm_1/scale{suffix}",
|
||||
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/kv_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/q_einsum_1/w{suffix}",
|
||||
f"llm/layers/mlp_1/gating_einsum{suffix}",
|
||||
f"llm/layers/mlp_1/linear{suffix}",
|
||||
f"llm/layers/pre_attention_norm_1/scale{suffix}",
|
||||
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
|
||||
]:
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
expert_dict[key] = value
|
||||
|
||||
return final_state_dict, expert_dict
|
||||
|
||||
|
||||
def slice_gemma_state_dict(state_dict, config, num_expert=1):
|
||||
# fmt: off
|
||||
# text decoder (gemma)
|
||||
# no embedding vector, the expert just has the decoder layers
|
||||
|
||||
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
|
||||
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
|
||||
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
|
||||
|
||||
# fmt: on
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
final_state_dict[key] = value
|
||||
return final_state_dict
|
||||
|
||||
|
||||
def flatten_for_memory(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_memory(v, new_key))
|
||||
else:
|
||||
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
|
||||
return out
|
||||
|
||||
|
||||
def flatten_for_npz(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_npz(v, new_key))
|
||||
else:
|
||||
# bf16/f32 here?
|
||||
out[new_key] = np.array(v)
|
||||
return out
|
||||
|
||||
|
||||
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
|
||||
params_path = pathlib.Path(checkpoint_dir).resolve()
|
||||
checkpointer = ocp.PyTreeCheckpointer()
|
||||
|
||||
metadata = checkpointer.metadata(params_path)
|
||||
print("Metadata keys:", list(metadata.keys()))
|
||||
|
||||
params_name = "params"
|
||||
|
||||
item = {params_name: metadata[params_name]}
|
||||
device = jax.local_devices()[0] # Use the first local device
|
||||
sharding = SingleDeviceSharding(device)
|
||||
restored = checkpointer.restore(
|
||||
params_path,
|
||||
ocp.args.PyTreeRestore(
|
||||
item=item,
|
||||
restore_args=jax.tree_util.tree_map(
|
||||
lambda _: ocp.ArrayRestoreArgs(
|
||||
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
|
||||
sharding=sharding,
|
||||
),
|
||||
item,
|
||||
),
|
||||
transforms={},
|
||||
),
|
||||
)
|
||||
params = restored[params_name]
|
||||
|
||||
# get params for PaliGemma
|
||||
pali_params = params["PaliGemma"]
|
||||
del params["PaliGemma"]
|
||||
pali_params_flat = flatten_for_npz(pali_params)
|
||||
return {"paligemma_params": pali_params_flat, "projection_params": params}
|
||||
|
||||
|
||||
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
||||
"""Update dictionary keys by adding a prefix."""
|
||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
||||
# Break down orbax ckpts - they are in OCDBT
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||
# process projection params
|
||||
keys = [
|
||||
"state_proj",
|
||||
"action_in_proj",
|
||||
"action_out_proj",
|
||||
"action_time_mlp_in",
|
||||
"action_time_mlp_out",
|
||||
]
|
||||
|
||||
projection_params = {}
|
||||
for key in keys:
|
||||
kernel_params = initial_params["projection_params"][key]["kernel"]
|
||||
bias_params = initial_params["projection_params"][key]["bias"]
|
||||
if isinstance(kernel_params, dict):
|
||||
weight = kernel_params["value"]
|
||||
bias = bias_params["value"]
|
||||
else:
|
||||
weight = kernel_params
|
||||
bias = bias_params
|
||||
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
|
||||
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
|
||||
|
||||
# Process PaliGemma weights
|
||||
paligemma_config = get_paligemma_config(precision)
|
||||
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
|
||||
initial_params["paligemma_params"], paligemma_config
|
||||
)
|
||||
|
||||
# Process Gemma weights (at this stage they are unused)
|
||||
gemma_config = get_gemma_config(precision)
|
||||
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
|
||||
|
||||
# Instantiate model from configs
|
||||
|
||||
if "pi0_aloha_sim" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=2,
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
elif "pi0_aloha_towel" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=True,
|
||||
)
|
||||
elif "pi0_base" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=0,
|
||||
adapt_to_pi_aloha=False,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
# gemma_config=gemma_config, paligemma_config=paligemma_config)
|
||||
pi0_model = PI0Policy(pi0_config)
|
||||
|
||||
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
|
||||
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
|
||||
projection_params = update_keys_with_prefix(projection_params, "model.")
|
||||
|
||||
# load state dict
|
||||
torch_dtype = PRECISIONS[precision]
|
||||
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
|
||||
pi0_model = pi0_model.to(torch_dtype)
|
||||
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
||||
|
||||
pi0_model.save_pretrained(output_path, safe_serialization=True)
|
||||
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
|
||||
|
||||
# assert that model loads properly
|
||||
del pi0_model
|
||||
PI0Policy.from_pretrained(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--checkpoint_dir",
|
||||
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
|
||||
type=str,
|
||||
help="Path to the ocdbt checkpoint",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
default="float32",
|
||||
type=str,
|
||||
help="Precision identifier for model conversion - should match the base checkpoint precision.",
|
||||
)
|
||||
# tokenizer is identical to paligemma, it appears
|
||||
|
||||
parser.add_argument(
|
||||
"--tokenizer_hub_id",
|
||||
default="google/paligemma-3b-pt-224",
|
||||
type=str,
|
||||
help="Hub path to the tokenizer to save",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to save converted weights to",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_pi0_checkpoint(
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
precision=args.precision,
|
||||
tokenizer_id=args.tokenizer_hub_id,
|
||||
output_path=args.output_path,
|
||||
)
|
||||
127
lerobot/common/policies/pi0/flex_attention.py
Normal file
127
lerobot/common/policies/pi0/flex_attention.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from packaging.version import Version
|
||||
|
||||
if Version(torch.__version__) > Version("2.5.0"):
|
||||
# Ffex attention is only available from torch 2.5 onwards
|
||||
from torch.nn.attention.flex_attention import (
|
||||
_mask_mod_signature,
|
||||
_round_up_to_multiple,
|
||||
create_block_mask,
|
||||
create_mask,
|
||||
flex_attention,
|
||||
)
|
||||
|
||||
|
||||
# @torch.compile(dynamic=False)
|
||||
def flex_attention_forward(
|
||||
attention_mask: torch.Tensor,
|
||||
batch_size: int,
|
||||
head_dim: int,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
scaling=None,
|
||||
):
|
||||
"""
|
||||
This is defined out of classes to make compile happy.
|
||||
"""
|
||||
|
||||
original_dtype = query_states.dtype
|
||||
num_att_heads = 8
|
||||
num_key_value_heads = 1
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
key_states = key_states[:, :, :, None, :]
|
||||
key_states = key_states.expand(
|
||||
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :]
|
||||
value_states = value_states.expand(
|
||||
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
query_states = query_states.to(torch.float32)
|
||||
key_states = key_states.to(torch.float32)
|
||||
value_states = value_states.to(torch.float32)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if causal_mask is not None:
|
||||
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
|
||||
|
||||
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
|
||||
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
|
||||
|
||||
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
|
||||
def mask_mod(b, h, q_idx, kv_idx):
|
||||
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
|
||||
return precomputed_mask[b][h][q_idx][kv_idx]
|
||||
|
||||
return mask_mod
|
||||
|
||||
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
||||
|
||||
block_size = 128
|
||||
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
||||
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
||||
|
||||
# *CRITICAL* we do need to expand here, else we get a CUDA index error
|
||||
|
||||
pad_q = q_len_rounded - q_len
|
||||
pad_k = kv_len_rounded - kv_len
|
||||
|
||||
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
||||
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
||||
|
||||
mask_4d = create_mask(
|
||||
mod_fn=mask_mod_fn_orig,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
||||
block_mask = create_block_mask(
|
||||
mask_mod=mask_mod_fn_padded,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
BLOCK_SIZE=block_size,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||
attn_output, attention_weights = flex_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
block_mask=block_mask,
|
||||
enable_gqa=True, # because we shaped query/key states for GQA
|
||||
scale=head_dim**-0.5 if scaling is None else scaling,
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.to(dtype=original_dtype)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
|
||||
attn_output = attn_output.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
|
||||
)
|
||||
return attn_output
|
||||
732
lerobot/common/policies/pi0/modeling_pi0.py
Normal file
732
lerobot/common/policies/pi0/modeling_pi0.py
Normal file
@@ -0,0 +1,732 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||
|
||||
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Install pi0 extra dependencies:
|
||||
```bash
|
||||
pip install -e ".[pi0]"
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
||||
pretrained with VLM default parameters before pi0 finetuning:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of using the pi0 pretrained model outside LeRobot training framework:
|
||||
```python
|
||||
policy = Pi0Policy.from_pretrained("lerobot/pi0")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertConfig,
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
||||
return gamma1 / (gamma1 + gamma2)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
return att_2d_masks
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=-1):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
||||
new_vector[..., :current_dim] = vector
|
||||
return new_vector
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with pi0 which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class PI0Policy(PreTrainedPolicy):
|
||||
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = PI0Config
|
||||
name = "pi0"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
loss = losses.mean()
|
||||
# For backward pass
|
||||
loss_dict["loss"] = loss
|
||||
# For logging
|
||||
loss_dict["l2_loss"] = loss.item()
|
||||
return loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
||||
"""
|
||||
images = []
|
||||
img_masks = []
|
||||
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
|
||||
# Preprocess image features present in the batch
|
||||
for key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
# Create image features not present in the batch
|
||||
# as fully 0 padded images.
|
||||
for num_empty_cameras in range(len(missing_img_keys)):
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
break
|
||||
img = torch.ones_like(img) * -1
|
||||
mask = torch.zeros_like(mask)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_ROBOT].device
|
||||
tasks = batch["task"]
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
|
||||
class PI0FlowMatching(nn.Module):
|
||||
"""
|
||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||
|
||||
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
┌──────────────────────────────┐
|
||||
│ actions │
|
||||
│ ▲ │
|
||||
│ ┌┴─────┐ │
|
||||
│ kv cache │Gemma │ │
|
||||
│ ┌──────────►│Expert│ │
|
||||
│ │ │ │ │
|
||||
│ ┌┴────────┐ │x 10 │ │
|
||||
│ │ │ └▲──▲──┘ │
|
||||
│ │PaliGemma│ │ │ │
|
||||
│ │ │ │ robot state │
|
||||
│ │ │ noise │
|
||||
│ └▲──▲─────┘ │
|
||||
│ │ │ │
|
||||
│ │ image(s) │
|
||||
│ language tokens │
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
paligemma_with_export_config = PaliGemmaWithExpertConfig(
|
||||
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
attention_implementation=self.config.attention_implementation,
|
||||
)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
|
||||
|
||||
# Projections are float32
|
||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
|
||||
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
|
||||
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
|
||||
|
||||
self.set_requires_grad()
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for PaliGemma transformer processing.
|
||||
"""
|
||||
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# TODO: remove for loop
|
||||
for (
|
||||
img,
|
||||
img_mask,
|
||||
) in zip(images, img_masks, strict=False):
|
||||
img_emb = self.paligemma_with_expert.embed_image(img)
|
||||
img_emb = img_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask)
|
||||
|
||||
# Create attention masks so that image tokens attend to each other
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
|
||||
# Normalize language embeddings
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
# full attention between image and language inputs
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, state, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Embed state
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb.to(dtype=torch.bfloat16)
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
dtype = state_emb.dtype
|
||||
device = state_emb.device
|
||||
|
||||
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1]
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
action_emb = self.action_in_proj(noisy_actions)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
action_time_emb = F.silu(action_time_emb) # swish == silu
|
||||
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
||||
# Original openpi code, upcast attention output
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Compute image and language key value cache
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
||||
403
lerobot/common/policies/pi0/paligemma_with_expert.py
Normal file
403
lerobot/common/policies/pi0/paligemma_with_expert.py
Normal file
@@ -0,0 +1,403 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.version
|
||||
from pytest import Cache
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
GemmaForCausalLM,
|
||||
PaliGemmaForConditionalGeneration,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.common.policies.pi0.flex_attention import flex_attention_forward
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertConfig(PretrainedConfig):
|
||||
model_type = "PaliGemmaWithExpertModel"
|
||||
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
paligemma_config: dict | None = None,
|
||||
gemma_expert_config: dict | None = None,
|
||||
freeze_vision_encoder: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
attention_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_implementation = attention_implementation
|
||||
|
||||
if paligemma_config is None:
|
||||
# Default config from Pi0
|
||||
self.paligemma_config = CONFIG_MAPPING["paligemma"](
|
||||
transformers_version="4.48.1",
|
||||
_vocab_size=257152,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
hidden_size=2048,
|
||||
image_token_index=257152,
|
||||
model_type="paligemma",
|
||||
pad_token_id=0,
|
||||
projection_dim=2048,
|
||||
text_config={
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 16384,
|
||||
"model_type": "gemma",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 18,
|
||||
"num_image_tokens": 256,
|
||||
"num_key_value_heads": 1,
|
||||
"torch_dtype": "float32",
|
||||
"vocab_size": 257152,
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"num_image_tokens": 256,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"torch_dtype": "float32",
|
||||
"vision_use_head": False,
|
||||
},
|
||||
)
|
||||
elif isinstance(self.paligemma_config, dict):
|
||||
# Override Pi0 default config for PaliGemma
|
||||
if "model_type" not in gemma_expert_config:
|
||||
paligemma_config["model_type"] = "paligemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.paligemma_config = cfg_cls(**paligemma_config)
|
||||
|
||||
if gemma_expert_config is None:
|
||||
# Default config from Pi0
|
||||
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
hidden_size=1024,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=4096,
|
||||
max_position_embeddings=8192,
|
||||
model_type="gemma",
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=18,
|
||||
num_key_value_heads=1,
|
||||
pad_token_id=0,
|
||||
rms_norm_eps=1e-06,
|
||||
rope_theta=10000.0,
|
||||
torch_dtype="float32",
|
||||
transformers_version="4.48.1",
|
||||
use_cache=True,
|
||||
vocab_size=257152,
|
||||
)
|
||||
elif isinstance(self.gemma_expert_config, dict):
|
||||
# Override Pi0 default config for Gemma Expert
|
||||
if "model_type" not in gemma_expert_config:
|
||||
gemma_expert_config["model_type"] = "gemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.train_expert_only and not self.freeze_vision_encoder:
|
||||
raise ValueError(
|
||||
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
|
||||
)
|
||||
|
||||
if self.attention_implementation not in ["eager", "fa2", "flex"]:
|
||||
raise ValueError(
|
||||
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
config_class = PaliGemmaWithExpertConfig
|
||||
|
||||
def __init__(self, config: PaliGemmaWithExpertConfig):
|
||||
super().__init__(config=config)
|
||||
self.config = config
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
|
||||
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
|
||||
# Remove unused embed_tokens
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_like_physical_intelligence()
|
||||
self.set_requires_grad()
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
for params in self.paligemma.vision_tower.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
for params in self.paligemma.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def to_bfloat16_like_physical_intelligence(self):
|
||||
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
|
||||
|
||||
params_to_change_dtype = [
|
||||
"language_model.model.layers",
|
||||
"gemma_expert.model.layers",
|
||||
"vision_tower",
|
||||
"multi_modal",
|
||||
]
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch.bfloat16)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.get_image_features(image)
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.model.embed_tokens(tokens)
|
||||
|
||||
# TODO: break down this huge forward into modules or functions
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
inputs_embeds: List[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
):
|
||||
models = [self.paligemma.language_model.model, self.gemma_expert.model]
|
||||
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
head_dim = self.paligemma.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is None:
|
||||
continue
|
||||
layer = models[i].layers[layer_idx]
|
||||
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.bfloat16)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
|
||||
query_states = apply_rope(query_states, position_ids)
|
||||
key_states = apply_rope(key_states, position_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat(
|
||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
att_output = attention_interface(
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
|
||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
|
||||
if hidden_states is not None:
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
|
||||
|
||||
# TODO: first dropout (by default 0.0)
|
||||
|
||||
# first residual
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
# TODO: second dropout (by default 0.0)
|
||||
|
||||
# second residual
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
if self.config.attention_implementation == "fa2":
|
||||
attention_interface = self.flash_attention_forward
|
||||
elif self.config.attention_implementation == "flex":
|
||||
attention_interface = flex_attention_forward
|
||||
else:
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def flash_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
||||
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
# query_states: batch_size, sequence_length, num_att_head, head_dim
|
||||
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
|
||||
# value_states: batch_size, sequence_length, num_att_heads, head_dim
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
||||
@@ -1,75 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A protocol that all policies should follow.
|
||||
|
||||
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
|
||||
subclass a base class.
|
||||
|
||||
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
|
||||
how to implement new policies.
|
||||
"""
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Policy(Protocol):
|
||||
"""The required interface for implementing a policy.
|
||||
|
||||
We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
"""To be called whenever the environment is reset.
|
||||
|
||||
Does things like clearing caches.
|
||||
"""
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PolicyWithUpdate(Policy, Protocol):
|
||||
def update(self):
|
||||
"""An update method that is to be called after a training optimization step.
|
||||
|
||||
Implements an additional updates the model parameters may need (for example, doing an EMA step for a
|
||||
target model, or incrementing an internal buffer).
|
||||
"""
|
||||
182
lerobot/common/policies/pretrained.py
Normal file
182
lerobot/common/policies/pretrained.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor
|
||||
from safetensors.torch import save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
DEFAULT_POLICY_CARD = """
|
||||
---
|
||||
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
||||
# Doc / guide: https://huggingface.co/docs/hub/model-cards
|
||||
{{ card_data }}
|
||||
---
|
||||
|
||||
This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
|
||||
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
|
||||
"""
|
||||
|
||||
|
||||
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
"""
|
||||
Base class for policy models.
|
||||
"""
|
||||
|
||||
config_class: None
|
||||
name: None
|
||||
|
||||
def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, PreTrainedConfig):
|
||||
raise ValueError(
|
||||
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
|
||||
"`PreTrainedConfig`. To create a model from a pretrained model use "
|
||||
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "config_class", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: Type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
map_location: str = "cpu",
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""
|
||||
The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are
|
||||
deactivated). To train it, you should first set it back in training mode with `policy.train()`.
|
||||
"""
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
policy.to(map_location)
|
||||
policy.eval()
|
||||
return policy
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
|
||||
load_model_as_safetensor(model, model_file, strict=strict)
|
||||
if map_location != "cpu":
|
||||
logging.warning(
|
||||
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
||||
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
||||
" This leads to a slower loading time."
|
||||
" Please update safetensors to version 0.4.3 or above for improved performance."
|
||||
)
|
||||
model.to(map_location)
|
||||
else:
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
return model
|
||||
|
||||
# def generate_model_card(self, *args, **kwargs) -> ModelCard:
|
||||
# card = ModelCard.from_template(
|
||||
# card_data=self._hub_mixin_info.model_card_data,
|
||||
# template_str=self._hub_mixin_info.model_card_template,
|
||||
# repo_url=self._hub_mixin_info.repo_url,
|
||||
# docs_url=self._hub_mixin_info.docs_url,
|
||||
# **kwargs,
|
||||
# )
|
||||
# return card
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optim_params(self) -> dict:
|
||||
"""
|
||||
Returns the policy-specific parameters dict to be passed on to the optimizer.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self):
|
||||
"""To be called whenever the environment is reset.
|
||||
|
||||
Does things like clearing caches.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -16,9 +16,14 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("tdmpc")
|
||||
@dataclass
|
||||
class TDMPCConfig:
|
||||
class TDMPCConfig(PreTrainedConfig):
|
||||
"""Configuration class for TDMPCPolicy.
|
||||
|
||||
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
|
||||
@@ -102,27 +107,19 @@ class TDMPCConfig:
|
||||
"""
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
n_action_repeats: int = 2
|
||||
horizon: int = 5
|
||||
n_action_steps: int = 1
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ENV": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [4],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] | None = None
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"},
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
@@ -159,32 +156,27 @@ class TDMPCConfig:
|
||||
# Target model.
|
||||
target_model_momentum: float = 0.995
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 3e-4
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
if len(image_keys) > 0:
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
)
|
||||
if self.output_normalization_modes != {"action": "min_max"}:
|
||||
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
|
||||
raise ValueError(
|
||||
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
|
||||
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
|
||||
"information."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if self.n_action_steps > 1:
|
||||
if self.n_action_repeats != 1:
|
||||
raise ValueError(
|
||||
@@ -194,3 +186,35 @@ class TDMPCConfig:
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(lr=self.optimizer_lr)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# There should only be one image key.
|
||||
if len(self.image_features) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {self.image_features}."
|
||||
)
|
||||
|
||||
if len(self.image_features) > 0:
|
||||
image_ft = next(iter(self.image_features.values()))
|
||||
if image_ft.shape[-2] != image_ft.shape[-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(self.horizon + 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return list(range(self.horizon))
|
||||
|
||||
@@ -33,21 +33,16 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
|
||||
|
||||
class TDMPCPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "tdmpc"],
|
||||
):
|
||||
class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""Implementation of TD-MPC learning + inference.
|
||||
|
||||
Please note several warnings for this policy.
|
||||
@@ -65,11 +60,10 @@ class TDMPCPolicy(
|
||||
match our xarm environment.
|
||||
"""
|
||||
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(
|
||||
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
@@ -77,42 +71,28 @@ class TDMPCPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = TDMPCConfig()
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = TDMPCTOLD(config)
|
||||
self.model_target = deepcopy(self.model)
|
||||
for param in self.model_target.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
if len(image_keys) > 0:
|
||||
assert len(image_keys) == 1
|
||||
self._use_image = True
|
||||
self.input_image_key = image_keys[0]
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
|
||||
@@ -122,9 +102,9 @@ class TDMPCPolicy(
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self._use_image:
|
||||
if self.config.image_features:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
@@ -134,9 +114,9 @@ class TDMPCPolicy(
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -151,9 +131,9 @@ class TDMPCPolicy(
|
||||
|
||||
# NOTE: Order of observations matters here.
|
||||
encode_keys = []
|
||||
if self._use_image:
|
||||
if self.config.image_features:
|
||||
encode_keys.append("observation.image")
|
||||
if self._use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
encode_keys.append("observation.environment_state")
|
||||
encode_keys.append("observation.state")
|
||||
z = self.model.encode({k: batch[k] for k in encode_keys})
|
||||
@@ -196,7 +176,7 @@ class TDMPCPolicy(
|
||||
self.config.horizon,
|
||||
self.config.n_pi_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
self.config.action_feature.shape[0],
|
||||
device=device,
|
||||
)
|
||||
if self.config.n_pi_samples > 0:
|
||||
@@ -215,7 +195,7 @@ class TDMPCPolicy(
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
@@ -228,7 +208,7 @@ class TDMPCPolicy(
|
||||
self.config.horizon,
|
||||
self.config.n_gaussian_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
self.config.action_feature.shape[0],
|
||||
device=std.device,
|
||||
)
|
||||
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
|
||||
@@ -330,16 +310,16 @@ class TDMPCPolicy(
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b, action_dim)
|
||||
@@ -347,7 +327,7 @@ class TDMPCPolicy(
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
if self.config.image_features and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
@@ -360,7 +340,7 @@ class TDMPCPolicy(
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self._use_image else "observation.environment_state"
|
||||
"observation.image" if self.config.image_features else "observation.environment_state"
|
||||
].shape[:2]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
@@ -522,7 +502,7 @@ class TDMPCPolicy(
|
||||
|
||||
# Undo (b, t) -> (t, b).
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return info
|
||||
@@ -543,7 +523,7 @@ class TDMPCTOLD(nn.Module):
|
||||
self.config = config
|
||||
self._encoder = TDMPCObservationEncoder(config)
|
||||
self._dynamics = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -554,7 +534,7 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self._reward = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -569,12 +549,12 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]),
|
||||
nn.Linear(config.mlp_dim, config.action_feature.shape[0]),
|
||||
)
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -714,10 +694,13 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if "observation.image" in config.input_shapes:
|
||||
if config.image_features:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
||||
next(iter(config.image_features.values())).shape[0],
|
||||
config.image_encoder_hidden_dim,
|
||||
7,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
@@ -727,9 +710,8 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
|
||||
out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
@@ -738,19 +720,19 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
)
|
||||
if "observation.state" in config.input_shapes:
|
||||
|
||||
if config.robot_state_feature:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
|
||||
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
|
||||
if config.env_state_feature:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -765,12 +747,16 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
"""
|
||||
feat = []
|
||||
# NOTE: Order of observations matters here.
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
if self.config.image_features:
|
||||
feat.append(
|
||||
flatten_forward_unflatten(
|
||||
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
|
||||
)
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
|
||||
if self.config.robot_state_feature:
|
||||
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -47,3 +48,20 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
||||
Note: assumes that all parameters have the same dtype.
|
||||
"""
|
||||
return next(iter(module.parameters())).dtype
|
||||
|
||||
|
||||
def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
|
||||
"""
|
||||
Calculates the output shape of a PyTorch module given an input shape.
|
||||
|
||||
Args:
|
||||
module (nn.Module): a PyTorch module
|
||||
input_shape (tuple): A tuple representing the input shape, e.g., (batch_size, channels, height, width)
|
||||
|
||||
Returns:
|
||||
tuple: The output shape of the module.
|
||||
"""
|
||||
dummy_input = torch.zeros(size=input_shape)
|
||||
with torch.inference_mode():
|
||||
output = module(dummy_input)
|
||||
return tuple(output.shape)
|
||||
|
||||
@@ -18,9 +18,15 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vqbet")
|
||||
@dataclass
|
||||
class VQBeTConfig:
|
||||
class VQBeTConfig(PreTrainedConfig):
|
||||
"""Configuration class for VQ-BeT.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
@@ -90,26 +96,13 @@ class VQBeTConfig:
|
||||
n_action_pred_token: int = 3
|
||||
action_chunk_size: int = 5
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -139,29 +132,69 @@ class VQBeTConfig:
|
||||
bet_softmax_temperature: float = 0.1
|
||||
sequentially_select: bool = False
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-6
|
||||
optimizer_vqvae_lr: float = 1e-3
|
||||
optimizer_vqvae_weight_decay: float = 1e-4
|
||||
scheduler_warmup_steps: int = 500
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> VQBeTSchedulerConfig:
|
||||
return VQBeTSchedulerConfig(
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_vqvae_training_steps=self.n_vqvae_training_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# Note: this check was previously performed inside VQBeTRgbEncoder in the form of
|
||||
# assert len(image_keys) == 1
|
||||
if not len(self.image_features) == 1:
|
||||
raise ValueError("You must provide only one image among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from collections import deque
|
||||
from typing import Callable, List
|
||||
@@ -26,29 +25,23 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
|
||||
class VQBeTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vqbet"],
|
||||
):
|
||||
class VQBeTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
|
||||
"""
|
||||
|
||||
config_class = VQBeTConfig
|
||||
name = "vqbet"
|
||||
|
||||
def __init__(
|
||||
@@ -63,26 +56,62 @@ class VQBeTPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = VQBeTConfig()
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
vqvae_params = (
|
||||
list(self.vqbet.action_head.vqvae_model.encoder.parameters())
|
||||
+ list(self.vqbet.action_head.vqvae_model.decoder.parameters())
|
||||
+ list(self.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
||||
)
|
||||
decay_params, no_decay_params = self.vqbet.policy.configure_parameters()
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(self.vqbet.rgb_encoder.parameters())
|
||||
+ list(self.vqbet.state_projector.parameters())
|
||||
+ list(self.vqbet.rgb_feature_projector.parameters())
|
||||
+ [self.vqbet.action_token]
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
|
||||
if self.config.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
|
||||
return [
|
||||
{
|
||||
"params": decay_params,
|
||||
},
|
||||
{
|
||||
"params": vqvae_params,
|
||||
"weight_decay": self.config.optimizer_vqvae_weight_decay,
|
||||
"lr": self.config.optimizer_vqvae_lr,
|
||||
},
|
||||
{
|
||||
"params": no_decay_params,
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
@@ -105,7 +134,7 @@ class VQBeTPolicy(
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -131,7 +160,7 @@ class VQBeTPolicy(
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
@@ -288,14 +317,14 @@ class VQBeTModel(nn.Module):
|
||||
self.config = config
|
||||
|
||||
self.rgb_encoder = VQBeTRgbEncoder(config)
|
||||
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self.num_images = len(self.config.image_features)
|
||||
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
|
||||
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
|
||||
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
@@ -350,10 +379,10 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# get action features (pass through GPT)
|
||||
features = self.policy(input_tokens)
|
||||
# len(self.config.input_shapes) is the number of different observation modes.
|
||||
# len(self.config.input_features) is the number of different observation modes.
|
||||
# this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
|
||||
self.config.input_shapes
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
|
||||
self.config.input_features
|
||||
)
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
@@ -392,7 +421,7 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
|
||||
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
|
||||
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`.
|
||||
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
@@ -419,7 +448,7 @@ class VQBeTHead(nn.Module):
|
||||
self.vqvae_model.vqvae_num_layers
|
||||
* self.config.vqvae_n_embed
|
||||
* config.action_chunk_size
|
||||
* config.output_shapes["action"][0],
|
||||
* config.action_feature.shape[0],
|
||||
],
|
||||
)
|
||||
# loss
|
||||
@@ -623,84 +652,6 @@ class VQBeTHead(nn.Module):
|
||||
return loss_dict
|
||||
|
||||
|
||||
class VQBeTOptimizer(torch.optim.Adam):
|
||||
def __init__(self, policy, cfg):
|
||||
vqvae_params = (
|
||||
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
||||
)
|
||||
decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.rgb_encoder.parameters())
|
||||
+ list(policy.vqbet.state_projector.parameters())
|
||||
+ list(policy.vqbet.rgb_feature_projector.parameters())
|
||||
+ [policy.vqbet.action_token]
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
|
||||
if cfg.policy.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
|
||||
optim_groups = [
|
||||
{
|
||||
"params": decay_params,
|
||||
"weight_decay": cfg.training.adam_weight_decay,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
{
|
||||
"params": vqvae_params,
|
||||
"weight_decay": 0.0001,
|
||||
"lr": cfg.training.vqvae_lr,
|
||||
},
|
||||
{
|
||||
"params": no_decay_params,
|
||||
"weight_decay": 0.0,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
]
|
||||
super().__init__(
|
||||
optim_groups,
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
)
|
||||
|
||||
|
||||
class VQBeTScheduler(nn.Module):
|
||||
def __init__(self, optimizer, cfg):
|
||||
super().__init__()
|
||||
n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
|
||||
|
||||
num_warmup_steps = cfg.training.lr_warmup_steps
|
||||
num_training_steps = cfg.training.offline_steps
|
||||
num_cycles = 0.5
|
||||
|
||||
def lr_lambda(current_step):
|
||||
if current_step < n_vqvae_training_steps:
|
||||
return float(1)
|
||||
else:
|
||||
current_step = current_step - n_vqvae_training_steps
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
||||
|
||||
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
def step(self):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
|
||||
class VQBeTRgbEncoder(nn.Module):
|
||||
"""Encode an RGB image into a 1D feature vector.
|
||||
|
||||
@@ -743,19 +694,15 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
assert len(image_keys) == 1
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
@@ -844,7 +791,7 @@ class VqVae(nn.Module):
|
||||
)
|
||||
|
||||
self.encoder = MLP(
|
||||
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
@@ -856,7 +803,7 @@ class VqVae(nn.Module):
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
self.config.action_feature.shape[0] * self.config.action_chunk_size,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -872,9 +819,9 @@ class VqVae(nn.Module):
|
||||
# given latent vector, this function outputs the decoded action.
|
||||
output = self.decoder(latent)
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
else:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
|
||||
100
lerobot/common/robot_devices/cameras/configs.py
Normal file
100
lerobot/common/robot_devices/cameras/configs.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@dataclass
|
||||
class OpenCVCameraConfig(CameraConfig):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(0, 30, 640, 480)
|
||||
OpenCVCameraConfig(0, 60, 640, 480)
|
||||
OpenCVCameraConfig(0, 90, 640, 480)
|
||||
OpenCVCameraConfig(0, 30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
camera_index: int
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("intelrealsense")
|
||||
@dataclass
|
||||
class IntelRealSenseCameraConfig(CameraConfig):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 640, 480)
|
||||
IntelRealSenseCameraConfig(128422271347, 60, 640, 480)
|
||||
IntelRealSenseCameraConfig(128422271347, 90, 640, 480)
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 1280, 720)
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
|
||||
```
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
serial_number: int | None = None
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
use_depth: bool = False
|
||||
force_hardware_reset: bool = True
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# bool is stronger than is None, since it works with empty strings
|
||||
if bool(self.name) and bool(self.serial_number):
|
||||
raise ValueError(
|
||||
f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided."
|
||||
)
|
||||
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
|
||||
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
raise ValueError(
|
||||
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
|
||||
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
|
||||
)
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
@@ -11,13 +11,13 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
@@ -94,7 +94,10 @@ def save_images_from_cameras(
|
||||
cameras = []
|
||||
for cam_sn in serial_numbers:
|
||||
print(f"{cam_sn=}")
|
||||
camera = IntelRealSenseCamera(cam_sn, fps=fps, width=width, height=height, mock=mock)
|
||||
config = IntelRealSenseCameraConfig(
|
||||
serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock
|
||||
)
|
||||
camera = IntelRealSenseCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
||||
@@ -149,51 +152,6 @@ def save_images_from_cameras(
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntelRealSenseCameraConfig:
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
IntelRealSenseCameraConfig(30, 640, 480)
|
||||
IntelRealSenseCameraConfig(60, 640, 480)
|
||||
IntelRealSenseCameraConfig(90, 640, 480)
|
||||
IntelRealSenseCameraConfig(30, 1280, 720)
|
||||
IntelRealSenseCameraConfig(30, 640, 480, use_depth=True)
|
||||
IntelRealSenseCameraConfig(30, 640, 480, rotation=90)
|
||||
```
|
||||
"""
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
use_depth: bool = False
|
||||
force_hardware_reset: bool = True
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
|
||||
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
raise ValueError(
|
||||
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
|
||||
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
|
||||
)
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
|
||||
|
||||
class IntelRealSenseCamera:
|
||||
"""
|
||||
The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
|
||||
@@ -209,33 +167,35 @@ class IntelRealSenseCamera:
|
||||
When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
of the given camera will be used.
|
||||
|
||||
Example of usage:
|
||||
Example of instantiating with a serial number:
|
||||
```python
|
||||
# Instantiate with its serial number
|
||||
camera = IntelRealSenseCamera(128422271347)
|
||||
# Or by its name if it's unique
|
||||
camera = IntelRealSenseCamera.init_from_name("Intel RealSense D405")
|
||||
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
|
||||
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347)
|
||||
camera = IntelRealSenseCamera(config)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
# when done using the camera, consider disconnecting
|
||||
camera.disconnect()
|
||||
```
|
||||
|
||||
Example of instantiating with a name if it's unique:
|
||||
```
|
||||
config = IntelRealSenseCameraConfig(name="Intel RealSense D405")
|
||||
```
|
||||
|
||||
Example of changing default fps, width, height and color_mode:
|
||||
```python
|
||||
camera = IntelRealSenseCamera(serial_number, fps=30, width=1280, height=720)
|
||||
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
|
||||
|
||||
camera = IntelRealSenseCamera(serial_number, fps=90, width=640, height=480)
|
||||
camera = connect()
|
||||
|
||||
camera = IntelRealSenseCamera(serial_number, fps=90, width=640, height=480, color_mode="bgr")
|
||||
camera = connect()
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
|
||||
# Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
|
||||
```
|
||||
|
||||
Example of returning depth:
|
||||
```python
|
||||
camera = IntelRealSenseCamera(serial_number, use_depth=True)
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True)
|
||||
camera = IntelRealSenseCamera(config)
|
||||
camera.connect()
|
||||
color_image, depth_map = camera.read()
|
||||
```
|
||||
@@ -243,17 +203,13 @@ class IntelRealSenseCamera:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
serial_number: int,
|
||||
config: IntelRealSenseCameraConfig | None = None,
|
||||
**kwargs,
|
||||
config: IntelRealSenseCameraConfig,
|
||||
):
|
||||
if config is None:
|
||||
config = IntelRealSenseCameraConfig()
|
||||
|
||||
# Overwrite the config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
self.serial_number = serial_number
|
||||
self.config = config
|
||||
if config.name is not None:
|
||||
self.serial_number = self.find_serial_number_from_name(config.name)
|
||||
else:
|
||||
self.serial_number = config.serial_number
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
@@ -285,8 +241,7 @@ class IntelRealSenseCamera:
|
||||
elif config.rotation == 180:
|
||||
self.rotation = cv2.ROTATE_180
|
||||
|
||||
@classmethod
|
||||
def init_from_name(cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs):
|
||||
def find_serial_number_from_name(self, name):
|
||||
camera_infos = find_cameras()
|
||||
camera_names = [cam["name"] for cam in camera_infos]
|
||||
this_name_count = Counter(camera_names)[name]
|
||||
@@ -299,13 +254,7 @@ class IntelRealSenseCamera:
|
||||
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
|
||||
cam_sn = name_to_serial_dict[name]
|
||||
|
||||
if config is None:
|
||||
config = IntelRealSenseCameraConfig()
|
||||
|
||||
# Overwrite the config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
return cls(serial_number=cam_sn, config=config, **kwargs)
|
||||
return cam_sn
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
|
||||
@@ -9,13 +9,13 @@ import platform
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
@@ -126,7 +126,8 @@ def save_images_from_cameras(
|
||||
print("Connecting cameras")
|
||||
cameras = []
|
||||
for cam_idx in camera_ids:
|
||||
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height, mock=mock)
|
||||
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
|
||||
@@ -175,39 +176,6 @@ def save_images_from_cameras(
|
||||
print(f"Images have been saved to {images_dir}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenCVCameraConfig:
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(30, 640, 480)
|
||||
OpenCVCameraConfig(60, 640, 480)
|
||||
OpenCVCameraConfig(90, 640, 480)
|
||||
OpenCVCameraConfig(30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
|
||||
|
||||
class OpenCVCamera:
|
||||
"""
|
||||
The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
|
||||
@@ -227,7 +195,10 @@ class OpenCVCamera:
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
camera = OpenCVCamera(camera_index=0)
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
|
||||
config = OpenCVCameraConfig(camera_index=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
# when done using the camera, consider disconnecting
|
||||
@@ -236,25 +207,16 @@ class OpenCVCamera:
|
||||
|
||||
Example of changing default fps, width, height and color_mode:
|
||||
```python
|
||||
camera = OpenCVCamera(0, fps=30, width=1280, height=720)
|
||||
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
|
||||
|
||||
camera = OpenCVCamera(0, fps=90, width=640, height=480)
|
||||
camera = connect()
|
||||
|
||||
camera = OpenCVCamera(0, fps=90, width=640, height=480, color_mode="bgr")
|
||||
camera = connect()
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720)
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480)
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr")
|
||||
# Note: might error out open `camera.connect()` if these settings are not compatible with the camera
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = OpenCVCameraConfig()
|
||||
|
||||
# Overwrite config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
self.camera_index = camera_index
|
||||
def __init__(self, config: OpenCVCameraConfig):
|
||||
self.config = config
|
||||
self.camera_index = config.camera_index
|
||||
self.port = None
|
||||
|
||||
# Linux uses ports for connecting to cameras
|
||||
@@ -266,7 +228,7 @@ class OpenCVCamera:
|
||||
# Retrieve the camera index from a potentially symlinked path
|
||||
self.camera_index = get_camera_index_from_unix_port(self.port)
|
||||
else:
|
||||
raise ValueError(f"Please check the provided camera_index: {camera_index}")
|
||||
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
||||
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
|
||||
@@ -2,6 +2,12 @@ from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import (
|
||||
CameraConfig,
|
||||
IntelRealSenseCameraConfig,
|
||||
OpenCVCameraConfig,
|
||||
)
|
||||
|
||||
|
||||
# Defines a camera type
|
||||
class Camera(Protocol):
|
||||
@@ -9,3 +15,39 @@ class Camera(Protocol):
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray: ...
|
||||
def async_read(self) -> np.ndarray: ...
|
||||
def disconnect(self): ...
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[Camera]:
|
||||
cameras = {}
|
||||
|
||||
for key, cfg in camera_configs.items():
|
||||
if cfg.type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
cameras[key] = OpenCVCamera(cfg)
|
||||
|
||||
elif cfg.type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||
|
||||
cameras[key] = IntelRealSenseCamera(cfg)
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
def make_camera(camera_type, **kwargs) -> Camera:
|
||||
if camera_type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
config = OpenCVCameraConfig(**kwargs)
|
||||
return OpenCVCamera(config)
|
||||
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||
|
||||
config = IntelRealSenseCameraConfig(**kwargs)
|
||||
return IntelRealSenseCamera(config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
||||
|
||||
146
lerobot/common/robot_devices/control_configs.py
Normal file
146
lerobot/common/robot_devices/control_configs.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlConfig(draccus.ChoiceRegistry):
|
||||
pass
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("calibrate")
|
||||
@dataclass
|
||||
class CalibrateControlConfig(ControlConfig):
|
||||
# List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`)
|
||||
arms: list[str] | None = None
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("teleoperate")
|
||||
@dataclass
|
||||
class TeleoperateControlConfig(ControlConfig):
|
||||
# Limit the maximum frames per second. By default, no limit.
|
||||
fps: int | None = None
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("record")
|
||||
@dataclass
|
||||
class RecordControlConfig(ControlConfig):
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
||||
device: str | None = None # cuda | cpu | mps
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int | None = None
|
||||
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
||||
warmup_time_s: int | float = 10
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
# Number of seconds for resetting the environment after each episode.
|
||||
reset_time_s: int | float = 60
|
||||
# Number of episodes to record.
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.
|
||||
run_compute_stats: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
num_image_writer_processes: int = 0
|
||||
# Number of threads writing the frames as png images on disk, per camera.
|
||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
|
||||
local_files_only: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("control.policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("control.policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
# When no device or use_amp are given, use the one from training config.
|
||||
if self.device is None or self.use_amp is None:
|
||||
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
||||
if self.device is None:
|
||||
self.device = train_cfg.device
|
||||
if self.use_amp is None:
|
||||
self.use_amp = train_cfg.use_amp
|
||||
|
||||
# Automatically switch to available device if necessary
|
||||
if not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("replay")
|
||||
@dataclass
|
||||
class ReplayControlConfig(ControlConfig):
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str
|
||||
# Index of the episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the dataset fps.
|
||||
fps: int | None = None
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
|
||||
local_files_only: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlPipelineConfig:
|
||||
robot: RobotConfig
|
||||
control: ControlConfig
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["control.policy"]
|
||||
@@ -19,11 +19,9 @@ from termcolor import colored
|
||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import get_features_from_robot
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||||
|
||||
|
||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
@@ -89,10 +87,6 @@ def is_headless():
|
||||
return True
|
||||
|
||||
|
||||
def has_method(_object: object, method_name: str):
|
||||
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
|
||||
|
||||
|
||||
def predict_action(observation, policy, device, use_amp):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
@@ -161,26 +155,6 @@ def init_keyboard_listener():
|
||||
return listener, events
|
||||
|
||||
|
||||
def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
||||
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
use_amp = hydra_cfg.use_amp
|
||||
policy_fps = hydra_cfg.env.fps
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
return policy, policy_fps, device, use_amp
|
||||
|
||||
|
||||
def warmup_record(
|
||||
robot,
|
||||
events,
|
||||
@@ -233,9 +207,9 @@ def control_loop(
|
||||
dataset: LeRobotDataset | None = None,
|
||||
events=None,
|
||||
policy=None,
|
||||
device=None,
|
||||
use_amp=None,
|
||||
fps=None,
|
||||
device: torch.device | str | None = None,
|
||||
use_amp: bool | None = None,
|
||||
fps: int | None = None,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
@@ -253,6 +227,9 @@ def control_loop(
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
if isinstance(device, str):
|
||||
device = get_safe_torch_device(device)
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -324,21 +301,21 @@ def stop_recording(robot, listener, display_cameras):
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy):
|
||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
_, dataset_name = repo_id.split("/")
|
||||
# either repo_id doesnt start with "eval_" and there is no policy
|
||||
# or repo_id starts with "eval_" and there is a policy
|
||||
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy is None:
|
||||
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
if not dataset_name.startswith("eval_") and policy is not None:
|
||||
if not dataset_name.startswith("eval_") and policy_cfg is not None:
|
||||
raise ValueError(
|
||||
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})."
|
||||
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
|
||||
)
|
||||
|
||||
|
||||
|
||||
27
lerobot/common/robot_devices/motors/configs.py
Normal file
27
lerobot/common/robot_devices/motors/configs.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@MotorsBusConfig.register_subclass("dynamixel")
|
||||
@dataclass
|
||||
class DynamixelMotorsBusConfig(MotorsBusConfig):
|
||||
port: str
|
||||
motors: dict[str, tuple[int, str]]
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@MotorsBusConfig.register_subclass("feetech")
|
||||
@dataclass
|
||||
class FeetechMotorsBusConfig(MotorsBusConfig):
|
||||
port: str
|
||||
motors: dict[str, tuple[int, str]]
|
||||
mock: bool = False
|
||||
@@ -8,6 +8,7 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
@@ -252,7 +253,6 @@ class JointOutOfRangeError(Exception):
|
||||
|
||||
|
||||
class DynamixelMotorsBus:
|
||||
# TODO(rcadene): Add a script to find the motor indices without DynamixelWizzard2
|
||||
"""
|
||||
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
|
||||
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
|
||||
@@ -274,10 +274,11 @@ class DynamixelMotorsBus:
|
||||
motor_index = 6
|
||||
motor_model = "xl330-m288"
|
||||
|
||||
motors_bus = DynamixelMotorsBus(
|
||||
config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={motor_name: (motor_index, motor_model)},
|
||||
)
|
||||
motors_bus = DynamixelMotorsBus(config)
|
||||
motors_bus.connect()
|
||||
|
||||
position = motors_bus.read("Present_Position")
|
||||
@@ -293,23 +294,14 @@ class DynamixelMotorsBus:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
mock=False,
|
||||
config: DynamixelMotorsBusConfig,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.mock = mock
|
||||
self.port = config.port
|
||||
self.motors = config.motors
|
||||
self.mock = config.mock
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
|
||||
@@ -6,6 +6,7 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
@@ -322,10 +323,11 @@ class FeetechMotorsBus:
|
||||
motor_index = 6
|
||||
motor_model = "sts3215"
|
||||
|
||||
motors_bus = FeetechMotorsBus(
|
||||
config = FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={motor_name: (motor_index, motor_model)},
|
||||
)
|
||||
motors_bus = FeetechMotorsBus(config)
|
||||
motors_bus.connect()
|
||||
|
||||
position = motors_bus.read("Present_Position")
|
||||
@@ -341,23 +343,14 @@ class FeetechMotorsBus:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
mock=False,
|
||||
config: FeetechMotorsBusConfig,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.mock = mock
|
||||
self.port = config.port
|
||||
self.motors = config.motors
|
||||
self.mock = config.mock
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
@@ -367,8 +360,8 @@ class FeetechMotorsBus:
|
||||
self.group_writers = {}
|
||||
self.logs = {}
|
||||
|
||||
self.multi_turn_index = self.multi_turn_index = [0] * len(motors)
|
||||
self.previous_value = self.previous_value = [0] * len(motors)
|
||||
self.multi_turn_index = self.multi_turn_index = [0] * len(self.motors)
|
||||
self.previous_value = self.previous_value = [0] * len(self.motors)
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
from typing import Protocol
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import (
|
||||
DynamixelMotorsBusConfig,
|
||||
FeetechMotorsBusConfig,
|
||||
MotorsBusConfig,
|
||||
)
|
||||
|
||||
|
||||
class MotorsBus(Protocol):
|
||||
def motor_names(self): ...
|
||||
@@ -8,3 +14,40 @@ class MotorsBus(Protocol):
|
||||
def revert_calibration(self): ...
|
||||
def read(self): ...
|
||||
def write(self): ...
|
||||
|
||||
|
||||
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
|
||||
motors_buses = {}
|
||||
|
||||
for key, cfg in motors_bus_configs.items():
|
||||
if cfg.type == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
motors_buses[key] = DynamixelMotorsBus(cfg)
|
||||
|
||||
elif cfg.type == "feetech":
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
|
||||
|
||||
motors_buses[key] = FeetechMotorsBus(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
return motors_buses
|
||||
|
||||
|
||||
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
if motor_type == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
config = DynamixelMotorsBusConfig(**kwargs)
|
||||
return DynamixelMotorsBus(config)
|
||||
|
||||
elif motor_type == "feetech":
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
|
||||
|
||||
config = FeetechMotorsBusConfig(**kwargs)
|
||||
return FeetechMotorsBus(config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||
|
||||
516
lerobot/common/robot_devices/robots/configs.py
Normal file
516
lerobot/common/robot_devices/robots/configs.py
Normal file
@@ -0,0 +1,516 @@
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Sequence
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import (
|
||||
CameraConfig,
|
||||
IntelRealSenseCameraConfig,
|
||||
OpenCVCameraConfig,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.configs import (
|
||||
DynamixelMotorsBusConfig,
|
||||
FeetechMotorsBusConfig,
|
||||
MotorsBusConfig,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction
|
||||
@dataclass
|
||||
class ManipulatorRobotConfig(RobotConfig):
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
|
||||
|
||||
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
|
||||
# as the number of motors in your follower arms (assumes all follower arms have the same number of
|
||||
# motors).
|
||||
max_relative_target: list[float] | float | None = None
|
||||
|
||||
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
|
||||
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
|
||||
# gripper is not put in torque mode.
|
||||
gripper_open_degree: float | None = None
|
||||
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mock:
|
||||
for arm in self.leader_arms.values():
|
||||
if not arm.mock:
|
||||
arm.mock = True
|
||||
for arm in self.follower_arms.values():
|
||||
if not arm.mock:
|
||||
arm.mock = True
|
||||
for cam in self.cameras.values():
|
||||
if not cam.mock:
|
||||
cam.mock = True
|
||||
|
||||
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
|
||||
for name in self.follower_arms:
|
||||
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
|
||||
raise ValueError(
|
||||
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
|
||||
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
|
||||
f"`max_relative_target` list has as many parameters as there are motors per arm. "
|
||||
"Note: This feature does not yet work with robots where different follower arms have "
|
||||
"different numbers of motors."
|
||||
)
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaRobotConfig(ManipulatorRobotConfig):
|
||||
# Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been
|
||||
# properly assembled, no manual calibration step is expected. If you need to run manual calibration,
|
||||
# simply update this path to ".cache/calibration/aloha"
|
||||
calibration_dir: str = ".cache/calibration/aloha_default"
|
||||
|
||||
# /!\ FOR SAFETY, READ THIS /!\
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
|
||||
# When you feel more confident with teleoperation or running the policy, you can extend
|
||||
# this safety limit and even removing it by setting it to `null`.
|
||||
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
|
||||
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
|
||||
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
|
||||
max_relative_target: int | None = 5
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"left": DynamixelMotorsBusConfig(
|
||||
# window_x
|
||||
port="/dev/ttyDXL_leader_left",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"waist": [1, "xm430-w350"],
|
||||
"shoulder": [2, "xm430-w350"],
|
||||
"shoulder_shadow": [3, "xm430-w350"],
|
||||
"elbow": [4, "xm430-w350"],
|
||||
"elbow_shadow": [5, "xm430-w350"],
|
||||
"forearm_roll": [6, "xm430-w350"],
|
||||
"wrist_angle": [7, "xm430-w350"],
|
||||
"wrist_rotate": [8, "xl430-w250"],
|
||||
"gripper": [9, "xc430-w150"],
|
||||
},
|
||||
),
|
||||
"right": DynamixelMotorsBusConfig(
|
||||
# window_x
|
||||
port="/dev/ttyDXL_leader_right",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"waist": [1, "xm430-w350"],
|
||||
"shoulder": [2, "xm430-w350"],
|
||||
"shoulder_shadow": [3, "xm430-w350"],
|
||||
"elbow": [4, "xm430-w350"],
|
||||
"elbow_shadow": [5, "xm430-w350"],
|
||||
"forearm_roll": [6, "xm430-w350"],
|
||||
"wrist_angle": [7, "xm430-w350"],
|
||||
"wrist_rotate": [8, "xl430-w250"],
|
||||
"gripper": [9, "xc430-w150"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"left": DynamixelMotorsBusConfig(
|
||||
port="/dev/ttyDXL_follower_left",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"waist": [1, "xm540-w270"],
|
||||
"shoulder": [2, "xm540-w270"],
|
||||
"shoulder_shadow": [3, "xm540-w270"],
|
||||
"elbow": [4, "xm540-w270"],
|
||||
"elbow_shadow": [5, "xm540-w270"],
|
||||
"forearm_roll": [6, "xm540-w270"],
|
||||
"wrist_angle": [7, "xm540-w270"],
|
||||
"wrist_rotate": [8, "xm430-w350"],
|
||||
"gripper": [9, "xm430-w350"],
|
||||
},
|
||||
),
|
||||
"right": DynamixelMotorsBusConfig(
|
||||
port="/dev/ttyDXL_follower_right",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"waist": [1, "xm540-w270"],
|
||||
"shoulder": [2, "xm540-w270"],
|
||||
"shoulder_shadow": [3, "xm540-w270"],
|
||||
"elbow": [4, "xm540-w270"],
|
||||
"elbow_shadow": [5, "xm540-w270"],
|
||||
"forearm_roll": [6, "xm540-w270"],
|
||||
"wrist_angle": [7, "xm540-w270"],
|
||||
"wrist_rotate": [8, "xm430-w350"],
|
||||
"gripper": [9, "xm430-w350"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Troubleshooting: If one of your IntelRealSense cameras freeze during
|
||||
# data recording due to bandwidth limit, you might need to plug the camera
|
||||
# on another USB hub or PCIe card.
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"cam_high": IntelRealSenseCameraConfig(
|
||||
serial_number=128422271347,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"cam_low": IntelRealSenseCameraConfig(
|
||||
serial_number=130322270656,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"cam_left_wrist": IntelRealSenseCameraConfig(
|
||||
serial_number=218622272670,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"cam_right_wrist": IntelRealSenseCameraConfig(
|
||||
serial_number=130322272300,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("koch")
|
||||
@dataclass
|
||||
class KochRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/koch"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0085511",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"laptop": OpenCVCameraConfig(
|
||||
camera_index=0,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"phone": OpenCVCameraConfig(
|
||||
camera_index=1,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# ~ Koch specific settings ~
|
||||
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
gripper_open_degree: float = 35.156
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("koch_bimanual")
|
||||
@dataclass
|
||||
class KochBimanualRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/koch_bimanual"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"left": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0085511",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
"right": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"left": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
"right": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"laptop": OpenCVCameraConfig(
|
||||
camera_index=0,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"phone": OpenCVCameraConfig(
|
||||
camera_index=1,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# ~ Koch specific settings ~
|
||||
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
gripper_open_degree: float = 35.156
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("moss")
|
||||
@dataclass
|
||||
class MossRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/moss"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"laptop": OpenCVCameraConfig(
|
||||
camera_index=0,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"phone": OpenCVCameraConfig(
|
||||
camera_index=1,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so100")
|
||||
@dataclass
|
||||
class So100RobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/so100"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"laptop": OpenCVCameraConfig(
|
||||
camera_index=0,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"phone": OpenCVCameraConfig(
|
||||
camera_index=1,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("stretch")
|
||||
@dataclass
|
||||
class StretchRobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"navigation": OpenCVCameraConfig(
|
||||
camera_index="/dev/hello-nav-head-camera",
|
||||
fps=10,
|
||||
width=1280,
|
||||
height=720,
|
||||
rotation=-90,
|
||||
),
|
||||
"head": IntelRealSenseCameraConfig(
|
||||
name="Intel RealSense D435I",
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
rotation=90,
|
||||
),
|
||||
"wrist": IntelRealSenseCameraConfig(
|
||||
name="Intel RealSense D405",
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
@@ -1,9 +0,0 @@
|
||||
import hydra
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
|
||||
def make_robot(cfg: DictConfig) -> Robot:
|
||||
robot = hydra.utils.instantiate(cfg)
|
||||
return robot
|
||||
@@ -8,15 +8,14 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
||||
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
|
||||
@@ -41,50 +40,6 @@ def ensure_safe_goal_position(
|
||||
return safe_goal_pos
|
||||
|
||||
|
||||
@dataclass
|
||||
class ManipulatorRobotConfig:
|
||||
"""
|
||||
Example of usage:
|
||||
```python
|
||||
ManipulatorRobotConfig()
|
||||
```
|
||||
"""
|
||||
|
||||
# Define all components of the robot
|
||||
robot_type: str = "koch"
|
||||
leader_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
|
||||
follower_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
|
||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||
|
||||
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
|
||||
# as the number of motors in your follower arms (assumes all follower arms have the same number of
|
||||
# motors).
|
||||
max_relative_target: list[float] | float | None = None
|
||||
|
||||
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
|
||||
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
|
||||
# gripper is not put in torque mode.
|
||||
gripper_open_degree: float | None = None
|
||||
|
||||
def __setattr__(self, prop: str, val):
|
||||
if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
|
||||
for name in self.follower_arms:
|
||||
if len(self.follower_arms[name].motors) != len(val):
|
||||
raise ValueError(
|
||||
f"len(max_relative_target)={len(val)} but the follower arm with name {name} has "
|
||||
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
|
||||
f"`max_relative_target` list has as many parameters as there are motors per arm. "
|
||||
"Note: This feature does not yet work with robots where different follower arms have "
|
||||
"different numbers of motors."
|
||||
)
|
||||
super().__setattr__(prop, val)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.robot_type not in ["koch", "koch_bimanual", "aloha", "so100", "moss"]:
|
||||
raise ValueError(f"Provided robot type ({self.robot_type}) is not supported.")
|
||||
|
||||
|
||||
class ManipulatorRobot:
|
||||
# TODO(rcadene): Implement force feedback
|
||||
"""This class allows to control any manipulator robot of various number of motors.
|
||||
@@ -95,11 +50,16 @@ class ManipulatorRobot:
|
||||
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
|
||||
- [Aloha](https://www.trossenrobotics.com/aloha-kits) developed by Trossen Robotics
|
||||
|
||||
Example of highest frequency teleoperation without camera:
|
||||
Example of instantiation, a pre-defined robot config is required:
|
||||
```python
|
||||
robot = ManipulatorRobot(KochRobotConfig())
|
||||
```
|
||||
|
||||
Example of overwritting motors during instantiation:
|
||||
```python
|
||||
# Defines how to communicate with the motors of the leader and follower arms
|
||||
leader_arms = {
|
||||
"main": DynamixelMotorsBus(
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
@@ -113,7 +73,7 @@ class ManipulatorRobot:
|
||||
),
|
||||
}
|
||||
follower_arms = {
|
||||
"main": DynamixelMotorsBus(
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
@@ -126,35 +86,11 @@ class ManipulatorRobot:
|
||||
},
|
||||
),
|
||||
}
|
||||
robot = ManipulatorRobot(
|
||||
robot_type="koch",
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
leader_arms=leader_arms,
|
||||
follower_arms=follower_arms,
|
||||
)
|
||||
|
||||
# Connect motors buses and cameras if any (Required)
|
||||
robot.connect()
|
||||
|
||||
while True:
|
||||
robot.teleop_step()
|
||||
robot_config = KochRobotConfig(leader_arms=leader_arms, follower_arms=follower_arms)
|
||||
robot = ManipulatorRobot(robot_config)
|
||||
```
|
||||
|
||||
Example of highest frequency data collection without camera:
|
||||
```python
|
||||
# Assumes leader and follower arms have been instantiated already (see first example)
|
||||
robot = ManipulatorRobot(
|
||||
robot_type="koch",
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
leader_arms=leader_arms,
|
||||
follower_arms=follower_arms,
|
||||
)
|
||||
robot.connect()
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
```
|
||||
|
||||
Example of highest frequency data collection with cameras:
|
||||
Example of overwritting cameras during instantiation:
|
||||
```python
|
||||
# Defines how to communicate with 2 cameras connected to the computer.
|
||||
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
|
||||
@@ -164,31 +100,28 @@ class ManipulatorRobot:
|
||||
"laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
|
||||
}
|
||||
robot = ManipulatorRobot(KochRobotConfig(cameras=cameras))
|
||||
```
|
||||
|
||||
# Assumes leader and follower arms have been instantiated already (see first example)
|
||||
robot = ManipulatorRobot(
|
||||
robot_type="koch",
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
leader_arms=leader_arms,
|
||||
follower_arms=follower_arms,
|
||||
cameras=cameras,
|
||||
)
|
||||
Once the robot is instantiated, connect motors buses and cameras if any (Required):
|
||||
```python
|
||||
robot.connect()
|
||||
```
|
||||
|
||||
Example of highest frequency teleoperation, which doesn't require cameras:
|
||||
```python
|
||||
while True:
|
||||
robot.teleop_step()
|
||||
```
|
||||
|
||||
Example of highest frequency data collection from motors and cameras (if any):
|
||||
```python
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
```
|
||||
|
||||
Example of controlling the robot with a policy (without running multiple policies in parallel to ensure highest frequency):
|
||||
Example of controlling the robot with a policy:
|
||||
```python
|
||||
# Assumes leader and follower arms + cameras have been instantiated already (see previous example)
|
||||
robot = ManipulatorRobot(
|
||||
robot_type="koch",
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
leader_arms=leader_arms,
|
||||
follower_arms=follower_arms,
|
||||
cameras=cameras,
|
||||
)
|
||||
robot.connect()
|
||||
while True:
|
||||
# Uses the follower arms and cameras to capture an observation
|
||||
observation = robot.capture_observation()
|
||||
@@ -209,20 +142,14 @@ class ManipulatorRobot:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ManipulatorRobotConfig | None = None,
|
||||
calibration_dir: Path = ".cache/calibration/koch",
|
||||
**kwargs,
|
||||
config: ManipulatorRobotConfig,
|
||||
):
|
||||
if config is None:
|
||||
config = ManipulatorRobotConfig()
|
||||
# Overwrite config arguments using kwargs
|
||||
self.config = replace(config, **kwargs)
|
||||
self.calibration_dir = Path(calibration_dir)
|
||||
|
||||
self.robot_type = self.config.robot_type
|
||||
self.leader_arms = self.config.leader_arms
|
||||
self.follower_arms = self.config.follower_arms
|
||||
self.cameras = self.config.cameras
|
||||
self.config = config
|
||||
self.robot_type = self.config.type
|
||||
self.calibration_dir = Path(self.config.calibration_dir)
|
||||
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
|
||||
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
|
||||
self.cameras = make_cameras_from_configs(self.config.cameras)
|
||||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
|
||||
@@ -15,23 +15,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field, replace
|
||||
from dataclasses import replace
|
||||
|
||||
import torch
|
||||
from stretch_body.gamepad_teleop import GamePadTeleop
|
||||
from stretch_body.robot import Robot as StretchAPI
|
||||
from stretch_body.robot_params import RobotParams
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
|
||||
|
||||
@dataclass
|
||||
class StretchRobotConfig:
|
||||
robot_type: str | None = "stretch"
|
||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||
# TODO(aliberts): add feature with max_relative target
|
||||
# TODO(aliberts): add comment on max_relative target
|
||||
max_relative_target: list[float] | float | None = None
|
||||
from lerobot.common.robot_devices.robots.configs import StretchRobotConfig
|
||||
|
||||
|
||||
class StretchRobot(StretchAPI):
|
||||
@@ -40,11 +31,12 @@ class StretchRobot(StretchAPI):
|
||||
def __init__(self, config: StretchRobotConfig | None = None, **kwargs):
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = StretchRobotConfig()
|
||||
# Overwrite config arguments using kwargs
|
||||
self.config = replace(config, **kwargs)
|
||||
self.config = StretchRobotConfig(**kwargs)
|
||||
else:
|
||||
# Overwrite config arguments using kwargs
|
||||
self.config = replace(config, **kwargs)
|
||||
|
||||
self.robot_type = self.config.robot_type
|
||||
self.robot_type = self.config.type
|
||||
self.cameras = self.config.cameras
|
||||
self.is_connected = False
|
||||
self.teleop = None
|
||||
|
||||
@@ -1,5 +1,16 @@
|
||||
from typing import Protocol
|
||||
|
||||
from lerobot.common.robot_devices.robots.configs import (
|
||||
AlohaRobotConfig,
|
||||
KochBimanualRobotConfig,
|
||||
KochRobotConfig,
|
||||
ManipulatorRobotConfig,
|
||||
MossRobotConfig,
|
||||
RobotConfig,
|
||||
So100RobotConfig,
|
||||
StretchRobotConfig,
|
||||
)
|
||||
|
||||
|
||||
def get_arm_id(name, arm_type):
|
||||
"""Returns the string identifier of a robot arm. For instance, for a bimanual manipulator
|
||||
@@ -19,3 +30,36 @@ class Robot(Protocol):
|
||||
def capture_observation(self): ...
|
||||
def send_action(self, action): ...
|
||||
def disconnect(self): ...
|
||||
|
||||
|
||||
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||
if robot_type == "aloha":
|
||||
return AlohaRobotConfig(**kwargs)
|
||||
elif robot_type == "koch":
|
||||
return KochRobotConfig(**kwargs)
|
||||
elif robot_type == "koch_bimanual":
|
||||
return KochBimanualRobotConfig(**kwargs)
|
||||
elif robot_type == "moss":
|
||||
return MossRobotConfig(**kwargs)
|
||||
elif robot_type == "so100":
|
||||
return So100RobotConfig(**kwargs)
|
||||
elif robot_type == "stretch":
|
||||
return StretchRobotConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
||||
|
||||
|
||||
def make_robot_from_config(config: RobotConfig):
|
||||
if isinstance(config, ManipulatorRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
|
||||
return ManipulatorRobot(config)
|
||||
else:
|
||||
from lerobot.common.robot_devices.robots.stretch import StretchRobot
|
||||
|
||||
return StretchRobot(config)
|
||||
|
||||
|
||||
def make_robot(robot_type: str, **kwargs) -> Robot:
|
||||
config = make_robot_config(robot_type, **kwargs)
|
||||
return make_robot_from_config(config)
|
||||
|
||||
188
lerobot/common/utils/hub.py
Normal file
188
lerobot/common/utils/hub.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
T = TypeVar("T", bound="HubMixin")
|
||||
|
||||
|
||||
class HubMixin:
|
||||
"""
|
||||
A Mixin containing the functionality to push an object to the hub.
|
||||
|
||||
This is similar to huggingface_hub.ModelHubMixin but is lighter and makes less assumptions about its
|
||||
subclasses (in particular, the fact that it's not necessarily a model).
|
||||
|
||||
The inheriting classes must implement '_save_pretrained' and 'from_pretrained'.
|
||||
"""
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: str | Path,
|
||||
*,
|
||||
repo_id: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
card_kwargs: dict[str, Any] | None = None,
|
||||
**push_to_hub_kwargs,
|
||||
) -> str | None:
|
||||
"""
|
||||
Save object in local directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `Path`):
|
||||
Path to directory in which the object will be saved.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your object to the Huggingface Hub after saving it.
|
||||
repo_id (`str`, *optional*):
|
||||
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
|
||||
not provided.
|
||||
card_kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional arguments passed to the card template to customize the card.
|
||||
push_to_hub_kwargs:
|
||||
Additional key word arguments passed along to the [`~HubMixin.push_to_hub`] method.
|
||||
Returns:
|
||||
`str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
|
||||
"""
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# save object (weights, files, etc.)
|
||||
self._save_pretrained(save_directory)
|
||||
|
||||
# push to the Hub if required
|
||||
if push_to_hub:
|
||||
if repo_id is None:
|
||||
repo_id = save_directory.name # Defaults to `save_directory` name
|
||||
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
"""
|
||||
Overwrite this method in subclass to define how to save your object.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `Path`):
|
||||
Path to directory in which the object files will be saved.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
cls: Type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""
|
||||
Download the object from the Huggingface Hub and instantiate it.
|
||||
|
||||
Args:
|
||||
pretrained_name_or_path (`str`, `Path`):
|
||||
- Either the `repo_id` (string) of the object hosted on the Hub, e.g. `lerobot/diffusion_pusht`.
|
||||
- Or a path to a `directory` containing the object files saved using `.save_pretrained`,
|
||||
e.g., `../path/to/my_model_directory/`.
|
||||
revision (`str`, *optional*):
|
||||
Revision on the Hub. Can be a branch name, a git tag or any commit id.
|
||||
Defaults to the latest commit on `main` branch.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether to force (re-)downloading the files from the Hub, overriding the existing cache.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
|
||||
token (`str` or `bool`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
||||
cached when running `huggingface-cli login`.
|
||||
cache_dir (`str`, `Path`, *optional*):
|
||||
Path to the folder where cached files are stored.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, avoid downloading the file and return the path to the local cached file if it exists.
|
||||
kwargs (`Dict`, *optional*):
|
||||
Additional kwargs to pass to the object during initialization.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@validate_hf_hub_args
|
||||
def push_to_hub(
|
||||
self,
|
||||
repo_id: str,
|
||||
*,
|
||||
commit_message: str | None = None,
|
||||
private: bool | None = None,
|
||||
token: str | None = None,
|
||||
branch: str | None = None,
|
||||
create_pr: bool | None = None,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
delete_patterns: list[str] | str | None = None,
|
||||
card_kwargs: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Upload model checkpoint to the Hub.
|
||||
|
||||
Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
|
||||
`delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
|
||||
details.
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
ID of the repository to push to (example: `"username/my-model"`).
|
||||
commit_message (`str`, *optional*):
|
||||
Message to commit while pushing.
|
||||
private (`bool`, *optional*):
|
||||
Whether the repository created should be private.
|
||||
If `None` (default), the repo will be public unless the organization's default is private.
|
||||
token (`str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
||||
cached when running `huggingface-cli login`.
|
||||
branch (`str`, *optional*):
|
||||
The git branch on which to push the model. This defaults to `"main"`.
|
||||
create_pr (`boolean`, *optional*):
|
||||
Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
|
||||
allow_patterns (`List[str]` or `str`, *optional*):
|
||||
If provided, only files matching at least one pattern are pushed.
|
||||
ignore_patterns (`List[str]` or `str`, *optional*):
|
||||
If provided, files matching any of the patterns are not pushed.
|
||||
delete_patterns (`List[str]` or `str`, *optional*):
|
||||
If provided, remote files matching any of the patterns will be deleted from the repo.
|
||||
card_kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional arguments passed to the card template to customize the card.
|
||||
|
||||
Returns:
|
||||
The url of the commit of your object in the given repository.
|
||||
"""
|
||||
api = HfApi(token=token)
|
||||
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
|
||||
|
||||
if commit_message is None:
|
||||
if "Policy" in self.__class__.__name__:
|
||||
commit_message = "Upload policy"
|
||||
elif "Config" in self.__class__.__name__:
|
||||
commit_message = "Upload config"
|
||||
else:
|
||||
commit_message = f"Upload {self.__class__.__name__}"
|
||||
|
||||
# Push the files to the repo in a single commit
|
||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||
saved_path = Path(tmp) / repo_id
|
||||
self.save_pretrained(saved_path, card_kwargs=card_kwargs)
|
||||
return api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
folder_path=saved_path,
|
||||
commit_message=commit_message,
|
||||
revision=branch,
|
||||
create_pr=create_pr,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
delete_patterns=delete_patterns,
|
||||
)
|
||||
@@ -19,14 +19,13 @@ import os.path as osp
|
||||
import platform
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from copy import copy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def none_or_int(value):
|
||||
@@ -41,9 +40,22 @@ def inside_slurm():
|
||||
return "SLURM_JOB_ID" in os.environ
|
||||
|
||||
|
||||
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||
def auto_select_torch_device() -> torch.device:
|
||||
"""Tries to select automatically a torch device."""
|
||||
if torch.cuda.is_available():
|
||||
logging.info("Cuda backend detected, using cuda.")
|
||||
return torch.device("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
logging.info("Metal backend detected, using cuda.")
|
||||
return torch.device("mps")
|
||||
else:
|
||||
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
match cfg_device:
|
||||
match try_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
@@ -55,13 +67,45 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||
if log:
|
||||
logging.warning("Using CPU, this will be slow.")
|
||||
case _:
|
||||
device = torch.device(cfg_device)
|
||||
device = torch.device(try_device)
|
||||
if log:
|
||||
logging.warning(f"Using custom {cfg_device} device.")
|
||||
logging.warning(f"Using custom {try_device} device.")
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
||||
"""
|
||||
mps is currently not compatible with float64
|
||||
"""
|
||||
if isinstance(device, torch.device):
|
||||
device = device.type
|
||||
if device == "mps" and dtype == torch.float64:
|
||||
return torch.float32
|
||||
else:
|
||||
return dtype
|
||||
|
||||
|
||||
def is_torch_device_available(try_device: str) -> bool:
|
||||
if try_device == "cuda":
|
||||
return torch.cuda.is_available()
|
||||
elif try_device == "mps":
|
||||
return torch.backends.mps.is_available()
|
||||
elif try_device == "cpu":
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unknown device '{try_device}.")
|
||||
|
||||
|
||||
def is_amp_available(device: str):
|
||||
if device in ["cuda", "cpu"]:
|
||||
return True
|
||||
elif device == "mps":
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def get_global_random_state() -> dict[str, Any]:
|
||||
"""Get the random state for `random`, `numpy`, and `torch`."""
|
||||
random_state_dict = {
|
||||
@@ -159,22 +203,6 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
)
|
||||
|
||||
|
||||
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
|
||||
"""Initialize a Hydra config given only the path to the relevant config file.
|
||||
|
||||
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
|
||||
"""
|
||||
# TODO(alexander-soare): Resolve configs without Hydra initialization.
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
# Hydra needs a path relative to this file.
|
||||
hydra.initialize(
|
||||
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
|
||||
version_base="1.2",
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
return cfg
|
||||
|
||||
|
||||
def print_cuda_memory_usage():
|
||||
"""Use this function to locate and debug memory leak."""
|
||||
import gc
|
||||
@@ -217,3 +245,17 @@ def log_say(text, play_sounds, blocking=False):
|
||||
|
||||
if play_sounds:
|
||||
say(text, blocking)
|
||||
|
||||
|
||||
def get_channel_first_image_shape(image_shape: tuple) -> tuple:
|
||||
shape = copy(image_shape)
|
||||
if shape[2] < shape[0] and shape[2] < shape[1]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif not (shape[0] < shape[1] and shape[0] < shape[2]):
|
||||
raise ValueError(image_shape)
|
||||
|
||||
return shape
|
||||
|
||||
|
||||
def has_method(cls: object, method_name: str):
|
||||
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
|
||||
|
||||
66
lerobot/configs/default.py
Normal file
66
lerobot/configs/default.py
Normal file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common import (
|
||||
policies, # noqa: F401
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransformsConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
|
||||
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datsets are provided.
|
||||
repo_id: str
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
local_files_only: bool = False
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = "pyav"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WandBConfig:
|
||||
enable: bool = False
|
||||
# Set to true to disable saving an artifact despite training.save_checkpoint=True
|
||||
disable_artifact: bool = False
|
||||
project: str = "lerobot"
|
||||
entity: str | None = None
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalConfig:
|
||||
n_episodes: int = 50
|
||||
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
||||
batch_size: int = 50
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.batch_size > self.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
f"({self.batch_size} > {self.n_episodes}). As a result, {self.batch_size} "
|
||||
f"eval environments will be instantiated, but only {self.n_episodes} will be used. "
|
||||
"This might significantly slow down evaluation. To fix this, you should update your command "
|
||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
|
||||
f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
|
||||
)
|
||||
@@ -1,130 +0,0 @@
|
||||
defaults:
|
||||
- _self_
|
||||
- env: pusht
|
||||
- policy: diffusion
|
||||
|
||||
hydra:
|
||||
run:
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
|
||||
job:
|
||||
name: default
|
||||
|
||||
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
||||
# `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
|
||||
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
||||
# regardless of what's provided with the training command at the time of resumption.
|
||||
resume: false
|
||||
device: cuda # cpu
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: false
|
||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
seed: ???
|
||||
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
|
||||
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datsets are provided.
|
||||
dataset_repo_id: lerobot/pusht
|
||||
video_backend: pyav
|
||||
|
||||
training:
|
||||
offline_steps: ???
|
||||
|
||||
# Number of workers for the offline training dataloader.
|
||||
num_workers: 4
|
||||
|
||||
batch_size: ???
|
||||
|
||||
eval_freq: ???
|
||||
log_freq: 200
|
||||
save_checkpoint: true
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: ???
|
||||
|
||||
# Online training. Note that the online training loop adopts most of the options above apart from the
|
||||
# dataloader options. Unless otherwise specified.
|
||||
# The online training look looks something like:
|
||||
#
|
||||
# for i in range(online_steps):
|
||||
# do_online_rollout_and_update_online_buffer()
|
||||
# for j in range(online_steps_between_rollouts):
|
||||
# batch = next(dataloader_with_offline_and_online_data)
|
||||
# loss = policy(batch)
|
||||
# loss.backward()
|
||||
# optimizer.step()
|
||||
#
|
||||
online_steps: ???
|
||||
# How many episodes to collect at once when we reach the online rollout part of the training loop.
|
||||
online_rollout_n_episodes: 1
|
||||
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
|
||||
# the policy. Ideally you should set this to by an even divisor or online_rollout_n_episodes.
|
||||
online_rollout_batch_size: 1
|
||||
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
|
||||
online_steps_between_rollouts: null
|
||||
# The proportion of online samples (vs offline samples) to include in the online training batches.
|
||||
online_sampling_ratio: 0.5
|
||||
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
|
||||
online_env_seed: null
|
||||
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
|
||||
# FIFO.
|
||||
online_buffer_capacity: null
|
||||
# The minimum number of frames to have in the online buffer before commencing online training.
|
||||
# If online_buffer_seed_size > online_rollout_n_episodes, the rollout will be run multiple times until the
|
||||
# seed size condition is satisfied.
|
||||
online_buffer_seed_size: 0
|
||||
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
|
||||
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
|
||||
# + eval + environment rendering simultaneously.
|
||||
do_online_rollout_async: false
|
||||
|
||||
image_transforms:
|
||||
# These transforms are all using standard torchvision.transforms.v2
|
||||
# You can find out how these transformations affect images here:
|
||||
# https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
|
||||
# We use a custom RandomSubsetApply container to sample them.
|
||||
# For each transform, the following parameters are available:
|
||||
# weight: This represents the multinomial probability (with no replacement)
|
||||
# used for sampling the transform. If the sum of the weights is not 1,
|
||||
# they will be normalized.
|
||||
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
|
||||
# (following uniform distribution) when it's applied.
|
||||
# Set this flag to `true` to enable transforms during training
|
||||
enable: false
|
||||
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
|
||||
# It's an integer in the interval [1, number of available transforms].
|
||||
max_num_transforms: 3
|
||||
# By default, transforms are applied in Torchvision's suggested order (shown below).
|
||||
# Set this to True to apply them in a random order.
|
||||
random_order: false
|
||||
brightness:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
contrast:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
saturation:
|
||||
weight: 1
|
||||
min_max: [0.5, 1.5]
|
||||
hue:
|
||||
weight: 1
|
||||
min_max: [-0.05, 0.05]
|
||||
sharpness:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
|
||||
eval:
|
||||
n_episodes: 1
|
||||
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
||||
batch_size: 1
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: false
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
# Set to true to disable saving an artifact despite save_checkpoint == True
|
||||
disable_artifact: false
|
||||
project: lerobot
|
||||
notes: ""
|
||||
14
lerobot/configs/env/aloha.yaml
vendored
14
lerobot/configs/env/aloha.yaml
vendored
@@ -1,14 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 50
|
||||
|
||||
env:
|
||||
name: aloha
|
||||
task: AlohaInsertion-v0
|
||||
state_dim: 14
|
||||
action_dim: 14
|
||||
fps: ${fps}
|
||||
episode_length: 400
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
10
lerobot/configs/env/aloha_real.yaml
vendored
10
lerobot/configs/env/aloha_real.yaml
vendored
@@ -1,10 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 18
|
||||
action_dim: 18
|
||||
fps: ${fps}
|
||||
13
lerobot/configs/env/dora_aloha_real.yaml
vendored
13
lerobot/configs/env/dora_aloha_real.yaml
vendored
@@ -1,13 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: dora
|
||||
task: DoraAloha-v0
|
||||
state_dim: 14
|
||||
action_dim: 14
|
||||
fps: ${fps}
|
||||
episode_length: 400
|
||||
gym:
|
||||
fps: ${fps}
|
||||
10
lerobot/configs/env/koch_real.yaml
vendored
10
lerobot/configs/env/koch_real.yaml
vendored
@@ -1,10 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
10
lerobot/configs/env/moss_real.yaml
vendored
10
lerobot/configs/env/moss_real.yaml
vendored
@@ -1,10 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
17
lerobot/configs/env/pusht.yaml
vendored
17
lerobot/configs/env/pusht.yaml
vendored
@@ -1,17 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 10
|
||||
|
||||
env:
|
||||
name: pusht
|
||||
task: PushT-v0
|
||||
image_size: 96
|
||||
state_dim: 2
|
||||
action_dim: 2
|
||||
fps: ${fps}
|
||||
episode_length: 300
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
visualization_width: 384
|
||||
visualization_height: 384
|
||||
10
lerobot/configs/env/so100_real.yaml
vendored
10
lerobot/configs/env/so100_real.yaml
vendored
@@ -1,10 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
17
lerobot/configs/env/xarm.yaml
vendored
17
lerobot/configs/env/xarm.yaml
vendored
@@ -1,17 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 15
|
||||
|
||||
env:
|
||||
name: xarm
|
||||
task: XarmLift-v0
|
||||
image_size: 84
|
||||
state_dim: 4
|
||||
action_dim: 4
|
||||
fps: ${fps}
|
||||
episode_length: 200
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
visualization_width: 384
|
||||
visualization_height: 384
|
||||
84
lerobot/configs/eval.py
Normal file
84
lerobot/configs/eval.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import datetime as dt
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common import envs, policies # noqa: F401
|
||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import EvalConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalPipelineConfig:
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch
|
||||
# (useful for debugging). This argument is mutually exclusive with `--config`.
|
||||
env: envs.EnvConfig
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
policy: PreTrainedConfig | None = None
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
||||
device: str | None = None # cuda | cpu | mps
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
seed: int | None = 1000
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
# When no device or use_amp are given, use the one from training config.
|
||||
if self.device is None or self.use_amp is None:
|
||||
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
||||
if self.device is None:
|
||||
self.device = train_cfg.device
|
||||
if self.use_amp is None:
|
||||
self.use_amp = train_cfg.use_amp
|
||||
|
||||
# Automatically switch to available device if necessary
|
||||
if not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
else:
|
||||
logging.warning(
|
||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{self.policy.type}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
|
||||
if not self.output_dir:
|
||||
now = dt.datetime.now()
|
||||
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||
self.output_dir = Path("outputs/eval") / eval_dir
|
||||
|
||||
if self.device is None:
|
||||
raise ValueError("Set one of the following device: cuda, cpu or mps")
|
||||
elif self.device == "cuda" and self.use_amp is None:
|
||||
raise ValueError("Set 'use_amp' to True or False.")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
125
lerobot/configs/parser.py
Normal file
125
lerobot/configs/parser.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import inspect
|
||||
import sys
|
||||
from argparse import ArgumentError
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.utils.utils import has_method
|
||||
|
||||
PATH_KEY = "path"
|
||||
draccus.set_config_type("json")
|
||||
|
||||
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
|
||||
For example, supposing the main script was called with:
|
||||
python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path
|
||||
|
||||
If called during execution of myscript.py, get_cli_overrides("arg2") will return:
|
||||
["--subarg1=abc" "--subarg2=some/path"]
|
||||
"""
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
attr_level_args = []
|
||||
detect_string = f"--{field_name}."
|
||||
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
|
||||
for arg in args:
|
||||
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
|
||||
denested_arg = f"--{arg.removeprefix(detect_string)}"
|
||||
attr_level_args.append(denested_arg)
|
||||
|
||||
return attr_level_args
|
||||
|
||||
|
||||
def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
prefix = f"--{arg_name}="
|
||||
for arg in args:
|
||||
if arg.startswith(prefix):
|
||||
return arg[len(prefix) :]
|
||||
return None
|
||||
|
||||
|
||||
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
|
||||
|
||||
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return parse_arg(f"{field_name}.{draccus.CHOICE_TYPE_KEY}", args)
|
||||
|
||||
|
||||
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
|
||||
"""
|
||||
Filters command-line arguments related to fields with specific path arguments.
|
||||
|
||||
Args:
|
||||
fields_to_filter (str | list[str]): A single str or a list of str whose arguments need to be filtered.
|
||||
args (Sequence[str] | None): The sequence of command-line arguments to be filtered.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
list[str]: A filtered list of arguments, with arguments related to the specified
|
||||
fields removed.
|
||||
|
||||
Raises:
|
||||
ArgumentError: If both a path argument (e.g., `--field_name.path`) and a type
|
||||
argument (e.g., `--field_name.type`) are specified for the same field.
|
||||
"""
|
||||
if isinstance(fields_to_filter, str):
|
||||
fields_to_filter = [fields_to_filter]
|
||||
|
||||
filtered_args = args
|
||||
for field in fields_to_filter:
|
||||
if get_path_arg(field, args):
|
||||
if get_type_arg(field, args):
|
||||
raise ArgumentError(
|
||||
argument=None,
|
||||
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
|
||||
)
|
||||
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
|
||||
|
||||
return filtered_args
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None):
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does two additional things:
|
||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||
- If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
|
||||
initialize it from there to allow to fetch configs from the hub directly
|
||||
"""
|
||||
|
||||
def wrapper_outer(fn):
|
||||
@wraps(fn)
|
||||
def wrapper_inner(*args, **kwargs):
|
||||
argspec = inspect.getfullargspec(fn)
|
||||
argtype = argspec.annotations[argspec.args[0]]
|
||||
if len(args) > 0 and type(args[0]) is argtype:
|
||||
cfg = args[0]
|
||||
args = args[1:]
|
||||
else:
|
||||
cli_args = sys.argv[1:]
|
||||
config_path_cli = parse_arg("config_path", cli_args)
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
path_fields = argtype.__get_path_fields__()
|
||||
cli_args = filter_path_args(path_fields, cli_args)
|
||||
if has_method(argtype, "from_pretrained") and config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
return wrapper_inner
|
||||
|
||||
return wrapper_outer
|
||||
145
lerobot/configs/policies.py
Normal file
145
lerobot/configs/policies.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import abc
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.common.optim.optimizers import OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
||||
T = TypeVar("T", bound="PreTrainedConfig")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""
|
||||
Base configuration class for policy models.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
|
||||
normalization mode to apply.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||
the original scale.
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
self.pretrained_path = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
def action_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate_features(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def robot_state_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.STATE:
|
||||
return ft
|
||||
return None
|
||||
|
||||
@property
|
||||
def env_state_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.ENV:
|
||||
return ft
|
||||
return None
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.output_features.items():
|
||||
if ft.type is FeatureType.ACTION:
|
||||
return ft
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: Type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**policy_kwargs,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
if Path(model_id).is_dir():
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
|
||||
# something like --policy.path (in addition to --policy.type)
|
||||
cli_overrides = policy_kwargs.pop("cli_overrides", [])
|
||||
return draccus.parse(cls, config_file, args=cli_overrides)
|
||||
@@ -1,82 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.top:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 100000
|
||||
online_steps: 0
|
||||
eval_freq: 20000
|
||||
save_freq: 20000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.top: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.top: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -1,121 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_aloha_real.yaml` to train on real-world datasets collected on Aloha or Aloha-2 robots.
|
||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high, cam_low) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training and inference with `control_robot.py`:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_aloha_real \
|
||||
# env=aloha_real
|
||||
# ```
|
||||
#
|
||||
# Example of usage for training and inference with [Dora-rs](https://github.com/dora-rs/dora-lerobot):
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_aloha_real \
|
||||
# env=dora_aloha_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/aloha_static_vinh_cup
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.cam_right_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_left_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_high:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_low:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.cam_right_wrist: [3, 480, 640]
|
||||
observation.images.cam_left_wrist: [3, 480, 640]
|
||||
observation.images.cam_high: [3, 480, 640]
|
||||
observation.images.cam_low: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.cam_right_wrist: mean_std
|
||||
observation.images.cam_left_wrist: mean_std
|
||||
observation.images.cam_high: mean_std
|
||||
observation.images.cam_low: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -1,102 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/koch_pick_place_lego
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -1,102 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/moss_pick_place_lego
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -1,102 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/so100_pick_place_lego
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -1,104 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
|
||||
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
|
||||
# https://github.com/huggingface/lerobot/pull/134 for more details.
|
||||
|
||||
seed: 100000
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
override_dataset_stats:
|
||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||
observation.image:
|
||||
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
|
||||
# from the original codebase, but we should remove these and train our own pretrained model
|
||||
observation.state:
|
||||
min: [13.456424, 32.938293]
|
||||
max: [496.14618, 510.9579]
|
||||
action:
|
||||
min: [12.0, 25.0]
|
||||
max: [511.0, 511.0]
|
||||
|
||||
training:
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
eval_freq: 25000
|
||||
save_freq: 25000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
grad_clip_norm: 10
|
||||
lr: 1.0e-4
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
adam_betas: [0.95, 0.999]
|
||||
adam_eps: 1.0e-8
|
||||
adam_weight_decay: 1.0e-6
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
policy:
|
||||
name: diffusion
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 2
|
||||
horizon: 16
|
||||
n_action_steps: 8
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.image: mean_std
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
crop_shape: [84, 84]
|
||||
crop_is_random: True
|
||||
pretrained_backbone_weights: null
|
||||
use_group_norm: True
|
||||
spatial_softmax_num_keypoints: 32
|
||||
# Unet.
|
||||
down_dims: [512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
diffusion_step_embed_dim: 128
|
||||
use_film_scale_modulation: True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: DDPM
|
||||
num_train_timesteps: 100
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
prediction_type: epsilon # epsilon / sample
|
||||
clip_sample: True
|
||||
clip_sample_range: 1.0
|
||||
|
||||
# Inference
|
||||
num_inference_steps: null # if not provided, defaults to `num_train_timesteps`
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: false
|
||||
@@ -1,110 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Defaults for training for the pusht_keypoints dataset.
|
||||
|
||||
# They keypoints are on the vertices of the rectangles that make up the PushT as documented in the PushT
|
||||
# environment:
|
||||
# https://github.com/huggingface/gym-pusht/blob/5e2489be9ff99ed9cd47b6c653dda3b7aa844d24/gym_pusht/envs/pusht.py#L522-L534
|
||||
# For completeness, the diagram is copied here:
|
||||
# 0───────────1
|
||||
# │ │
|
||||
# 3───4───5───2
|
||||
# │ │
|
||||
# │ │
|
||||
# │ │
|
||||
# │ │
|
||||
# 7───6
|
||||
|
||||
|
||||
# Note: The original work trains keypoints-only with conditioning via inpainting. Here, we encode the
|
||||
# observation along with the agent position and use the encoding as global conditioning for the denoising
|
||||
# U-Net.
|
||||
|
||||
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
|
||||
# https://github.com/huggingface/lerobot/pull/134 for more details.
|
||||
|
||||
seed: 100000
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
eval_freq: 5000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
grad_clip_norm: 10
|
||||
lr: 1.0e-4
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
adam_betas: [0.95, 0.999]
|
||||
adam_eps: 1.0e-8
|
||||
adam_weight_decay: 1.0e-6
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
policy:
|
||||
name: diffusion
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 2
|
||||
horizon: 16
|
||||
n_action_steps: 8
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
crop_shape: [84, 84]
|
||||
crop_is_random: True
|
||||
pretrained_backbone_weights: null
|
||||
use_group_norm: True
|
||||
spatial_softmax_num_keypoints: 32
|
||||
# Unet.
|
||||
down_dims: [256, 512, 1024]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
diffusion_step_embed_dim: 128
|
||||
use_film_scale_modulation: True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: DDIM
|
||||
num_train_timesteps: 100
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
prediction_type: epsilon # epsilon / sample
|
||||
clip_sample: True
|
||||
clip_sample_range: 1.0
|
||||
|
||||
# Inference
|
||||
num_inference_steps: 10 # if not provided, defaults to `num_train_timesteps`
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: false
|
||||
@@ -1,93 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/xarm_lift_medium
|
||||
|
||||
training:
|
||||
offline_steps: 50000
|
||||
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
save_freq: 10000
|
||||
eval_freq: 5000
|
||||
log_freq: 100
|
||||
|
||||
online_steps: 50000
|
||||
online_rollout_n_episodes: 1
|
||||
online_rollout_batch_size: 1
|
||||
# Note: in FOWM `online_steps_between_rollouts` is actually dynamically set to match exactly the length of
|
||||
# the last sampled episode.
|
||||
online_steps_between_rollouts: 50
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
# FOWM Push uses 10000 for `online_buffer_capacity`. Given that their maximum episode length for this task
|
||||
# is 25, 10000 is approx 400 of their episodes worth. Since our episodes are about 8 times longer, we'll use
|
||||
# 80000.
|
||||
online_buffer_capacity: 80000
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: tdmpc
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 2
|
||||
horizon: 5
|
||||
n_action_steps: 1
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.image: [3, 84, 84]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: null
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
state_encoder_hidden_dim: 256
|
||||
latent_dim: 50
|
||||
q_ensemble_size: 5
|
||||
mlp_dim: 512
|
||||
# Reinforcement learning.
|
||||
discount: 0.9
|
||||
|
||||
# Inference.
|
||||
use_mpc: true
|
||||
cem_iterations: 6
|
||||
max_std: 2.0
|
||||
min_std: 0.05
|
||||
n_gaussian_samples: 512
|
||||
n_pi_samples: 51
|
||||
uncertainty_regularizer_coeff: 1.0
|
||||
n_elites: 50
|
||||
elite_weighting_temperature: 0.5
|
||||
gaussian_mean_momentum: 0.1
|
||||
|
||||
# Training and loss computation.
|
||||
max_random_shift_ratio: 0.0476
|
||||
# Loss coefficients.
|
||||
reward_coeff: 0.5
|
||||
expectile_weight: 0.9
|
||||
value_coeff: 0.1
|
||||
consistency_coeff: 20.0
|
||||
advantage_scaling: 3.0
|
||||
pi_coeff: 0.5
|
||||
temporal_decay_coeff: 0.5
|
||||
# Target model.
|
||||
target_model_momentum: 0.995
|
||||
@@ -1,105 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# env=pusht \
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
# policy=tdmpc_pusht_keypoints \
|
||||
# eval.batch_size=50 \
|
||||
# eval.n_episodes=50 \
|
||||
# eval.use_async_envs=true \
|
||||
# device=cuda \
|
||||
# use_amp=true
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 0
|
||||
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 10000
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 40000
|
||||
online_buffer_seed_size: 0
|
||||
do_online_rollout_async: false
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: tdmpc
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 5
|
||||
n_action_steps: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
state_encoder_hidden_dim: 256
|
||||
latent_dim: 50
|
||||
q_ensemble_size: 5
|
||||
mlp_dim: 512
|
||||
# Reinforcement learning.
|
||||
discount: 0.98
|
||||
|
||||
# Inference.
|
||||
use_mpc: true
|
||||
cem_iterations: 6
|
||||
max_std: 2.0
|
||||
min_std: 0.05
|
||||
n_gaussian_samples: 512
|
||||
n_pi_samples: 51
|
||||
uncertainty_regularizer_coeff: 1.0
|
||||
n_elites: 50
|
||||
elite_weighting_temperature: 0.5
|
||||
gaussian_mean_momentum: 0.1
|
||||
|
||||
# Training and loss computation.
|
||||
max_random_shift_ratio: 0.0476
|
||||
# Loss coefficients.
|
||||
reward_coeff: 0.5
|
||||
expectile_weight: 0.9
|
||||
value_coeff: 0.1
|
||||
consistency_coeff: 20.0
|
||||
advantage_scaling: 3.0
|
||||
pi_coeff: 0.5
|
||||
temporal_decay_coeff: 0.5
|
||||
# Target model.
|
||||
target_model_momentum: 0.995
|
||||
@@ -1,103 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Defaults for training for the PushT dataset.
|
||||
|
||||
seed: 100000
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
override_dataset_stats:
|
||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||
observation.image:
|
||||
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
|
||||
# from the original codebase, but we should remove these and train our own pretrained model
|
||||
observation.state:
|
||||
min: [13.456424, 32.938293]
|
||||
max: [496.14618, 510.9579]
|
||||
action:
|
||||
min: [12.0, 25.0]
|
||||
max: [511.0, 511.0]
|
||||
|
||||
training:
|
||||
offline_steps: 250000
|
||||
online_steps: 0
|
||||
eval_freq: 25000
|
||||
save_freq: 25000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
grad_clip_norm: 10
|
||||
lr: 1.0e-4
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
adam_betas: [0.95, 0.999]
|
||||
adam_eps: 1.0e-8
|
||||
adam_weight_decay: 1.0e-6
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
# VQ-BeT specific
|
||||
vqvae_lr: 1.0e-3
|
||||
n_vqvae_training_steps: 20000
|
||||
bet_weight_decay: 2e-4
|
||||
bet_learning_rate: 5.5e-5
|
||||
bet_betas: [0.9, 0.999]
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.action_chunk_size} - 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
policy:
|
||||
name: vqbet
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 5
|
||||
n_action_pred_token: 7
|
||||
action_chunk_size: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.image: mean_std
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
crop_shape: [84, 84]
|
||||
crop_is_random: True
|
||||
pretrained_backbone_weights: null
|
||||
use_group_norm: True
|
||||
spatial_softmax_num_keypoints: 32
|
||||
# VQ-VAE
|
||||
n_vqvae_training_steps: ${training.n_vqvae_training_steps}
|
||||
vqvae_n_embed: 16
|
||||
vqvae_embedding_dim: 256
|
||||
vqvae_enc_hidden_dim: 128
|
||||
# VQ-BeT
|
||||
gpt_block_size: 500
|
||||
gpt_input_dim: 512
|
||||
gpt_output_dim: 512
|
||||
gpt_n_layer: 8
|
||||
gpt_n_head: 8
|
||||
gpt_hidden_dim: 512
|
||||
dropout: 0.1
|
||||
mlp_hidden_dim: 1024
|
||||
offset_loss_weight: 10000.
|
||||
primary_code_loss_weight: 5.0
|
||||
secondary_code_loss_weight: 0.5
|
||||
bet_softmax_temperature: 0.1
|
||||
sequentially_select: False
|
||||
@@ -1,117 +0,0 @@
|
||||
# [Aloha: A Low-Cost Hardware for Bimanual Teleoperation](https://www.trossenrobotics.com/aloha-stationary)
|
||||
# https://aloha-2.github.io
|
||||
|
||||
# Requires installing extras packages
|
||||
# With pip: `pip install -e ".[dynamixel intelrealsense]"`
|
||||
# With poetry: `poetry install --sync --extras "dynamixel intelrealsense"`
|
||||
|
||||
# See [tutorial](https://github.com/huggingface/lerobot/blob/main/examples/9_use_aloha.md)
|
||||
|
||||
|
||||
_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot
|
||||
robot_type: aloha
|
||||
# Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been
|
||||
# properly assembled, no manual calibration step is expected. If you need to run manual calibration,
|
||||
# simply update this path to ".cache/calibration/aloha"
|
||||
calibration_dir: .cache/calibration/aloha_default
|
||||
|
||||
# /!\ FOR SAFETY, READ THIS /!\
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
|
||||
# When you feel more confident with teleoperation or running the policy, you can extend
|
||||
# this safety limit and even removing it by setting it to `null`.
|
||||
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
|
||||
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
|
||||
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
|
||||
max_relative_target: 5
|
||||
|
||||
leader_arms:
|
||||
left:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/ttyDXL_leader_left
|
||||
motors: # window_x
|
||||
# name: (index, model)
|
||||
waist: [1, xm430-w350]
|
||||
shoulder: [2, xm430-w350]
|
||||
shoulder_shadow: [3, xm430-w350]
|
||||
elbow: [4, xm430-w350]
|
||||
elbow_shadow: [5, xm430-w350]
|
||||
forearm_roll: [6, xm430-w350]
|
||||
wrist_angle: [7, xm430-w350]
|
||||
wrist_rotate: [8, xl430-w250]
|
||||
gripper: [9, xc430-w150]
|
||||
right:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/ttyDXL_leader_right
|
||||
motors: # window_x
|
||||
# name: (index, model)
|
||||
waist: [1, xm430-w350]
|
||||
shoulder: [2, xm430-w350]
|
||||
shoulder_shadow: [3, xm430-w350]
|
||||
elbow: [4, xm430-w350]
|
||||
elbow_shadow: [5, xm430-w350]
|
||||
forearm_roll: [6, xm430-w350]
|
||||
wrist_angle: [7, xm430-w350]
|
||||
wrist_rotate: [8, xl430-w250]
|
||||
gripper: [9, xc430-w150]
|
||||
|
||||
follower_arms:
|
||||
left:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/ttyDXL_follower_left
|
||||
motors:
|
||||
# name: [index, model]
|
||||
waist: [1, xm540-w270]
|
||||
shoulder: [2, xm540-w270]
|
||||
shoulder_shadow: [3, xm540-w270]
|
||||
elbow: [4, xm540-w270]
|
||||
elbow_shadow: [5, xm540-w270]
|
||||
forearm_roll: [6, xm540-w270]
|
||||
wrist_angle: [7, xm540-w270]
|
||||
wrist_rotate: [8, xm430-w350]
|
||||
gripper: [9, xm430-w350]
|
||||
right:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/ttyDXL_follower_right
|
||||
motors:
|
||||
# name: [index, model]
|
||||
waist: [1, xm540-w270]
|
||||
shoulder: [2, xm540-w270]
|
||||
shoulder_shadow: [3, xm540-w270]
|
||||
elbow: [4, xm540-w270]
|
||||
elbow_shadow: [5, xm540-w270]
|
||||
forearm_roll: [6, xm540-w270]
|
||||
wrist_angle: [7, xm540-w270]
|
||||
wrist_rotate: [8, xm430-w350]
|
||||
gripper: [9, xm430-w350]
|
||||
|
||||
# Troubleshooting: If one of your IntelRealSense cameras freeze during
|
||||
# data recording due to bandwidth limit, you might need to plug the camera
|
||||
# on another USB hub or PCIe card.
|
||||
cameras:
|
||||
cam_high:
|
||||
_target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
|
||||
serial_number: 128422271347
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
cam_low:
|
||||
_target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
|
||||
serial_number: 130322270656
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
cam_left_wrist:
|
||||
_target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
|
||||
serial_number: 218622272670
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
cam_right_wrist:
|
||||
_target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
|
||||
serial_number: 130322272300
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
@@ -1,53 +0,0 @@
|
||||
_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot
|
||||
robot_type: koch
|
||||
calibration_dir: .cache/calibration/koch
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: null
|
||||
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
|
||||
# ~ Koch specific settings ~
|
||||
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
gripper_open_degree: 35.156
|
||||
@@ -1,75 +0,0 @@
|
||||
_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot
|
||||
robot_type: koch_bimanual
|
||||
calibration_dir: .cache/calibration/koch_bimanual
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: null
|
||||
|
||||
leader_arms:
|
||||
left:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem585A0085511
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
right:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
|
||||
follower_arms:
|
||||
left:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem585A0076891
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
right:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
|
||||
# ~ Koch specific settings ~
|
||||
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
gripper_open_degree: 35.156
|
||||
@@ -1,56 +0,0 @@
|
||||
# [Moss v1 robot arm](https://github.com/jess-moss/moss-robot-arms)
|
||||
|
||||
# Requires installing extras packages
|
||||
# With pip: `pip install -e ".[feetech]"`
|
||||
# With poetry: `poetry install --sync --extras "feetech"`
|
||||
|
||||
# See [tutorial](https://github.com/huggingface/lerobot/blob/main/examples/11_use_moss.md)
|
||||
|
||||
_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot
|
||||
robot_type: moss
|
||||
calibration_dir: .cache/calibration/moss
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: null
|
||||
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
|
||||
port: /dev/tty.usbmodem58760431091
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "sts3215"]
|
||||
shoulder_lift: [2, "sts3215"]
|
||||
elbow_flex: [3, "sts3215"]
|
||||
wrist_flex: [4, "sts3215"]
|
||||
wrist_roll: [5, "sts3215"]
|
||||
gripper: [6, "sts3215"]
|
||||
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
|
||||
port: /dev/tty.usbmodem58760431191
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "sts3215"]
|
||||
shoulder_lift: [2, "sts3215"]
|
||||
elbow_flex: [3, "sts3215"]
|
||||
wrist_flex: [4, "sts3215"]
|
||||
wrist_roll: [5, "sts3215"]
|
||||
gripper: [6, "sts3215"]
|
||||
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
@@ -1,56 +0,0 @@
|
||||
# [SO-100 robot arm](https://github.com/TheRobotStudio/SO-ARM100)
|
||||
|
||||
# Requires installing extras packages
|
||||
# With pip: `pip install -e ".[feetech]"`
|
||||
# With poetry: `poetry install --sync --extras "feetech"`
|
||||
|
||||
# See [tutorial](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md)
|
||||
|
||||
_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot
|
||||
robot_type: so100
|
||||
calibration_dir: .cache/calibration/so100
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: null
|
||||
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
|
||||
port: /dev/tty.usbmodem585A0077581
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "sts3215"]
|
||||
shoulder_lift: [2, "sts3215"]
|
||||
elbow_flex: [3, "sts3215"]
|
||||
wrist_flex: [4, "sts3215"]
|
||||
wrist_roll: [5, "sts3215"]
|
||||
gripper: [6, "sts3215"]
|
||||
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
|
||||
port: /dev/tty.usbmodem585A0080971
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "sts3215"]
|
||||
shoulder_lift: [2, "sts3215"]
|
||||
elbow_flex: [3, "sts3215"]
|
||||
wrist_flex: [4, "sts3215"]
|
||||
wrist_roll: [5, "sts3215"]
|
||||
gripper: [6, "sts3215"]
|
||||
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
@@ -1,33 +0,0 @@
|
||||
# [Stretch3 from Hello Robot](https://hello-robot.com/stretch-3-product)
|
||||
|
||||
# Requires installing extras packages
|
||||
# With pip: `pip install -e ".[stretch]"`
|
||||
# With poetry: `poetry install --sync --extras "stretch"`
|
||||
|
||||
# See [tutorial](https://github.com/huggingface/lerobot/blob/main/examples/8_use_stretch.md)
|
||||
|
||||
|
||||
_target_: lerobot.common.robot_devices.robots.stretch.StretchRobot
|
||||
robot_type: stretch3
|
||||
|
||||
cameras:
|
||||
navigation:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: /dev/hello-nav-head-camera
|
||||
fps: 10
|
||||
width: 1280
|
||||
height: 720
|
||||
rotation: -90
|
||||
head:
|
||||
_target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera.init_from_name
|
||||
name: Intel RealSense D435I
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
rotation: 90
|
||||
wrist:
|
||||
_target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera.init_from_name
|
||||
name: Intel RealSense D405
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
236
lerobot/configs/train.py
Normal file
236
lerobot/configs/train.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import datetime as dt
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.common import envs
|
||||
from lerobot.common.optim import OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OfflineConfig:
|
||||
steps: int = 100_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class OnlineConfig:
|
||||
"""
|
||||
The online training loop looks something like:
|
||||
|
||||
```python
|
||||
for i in range(steps):
|
||||
do_online_rollout_and_update_online_buffer()
|
||||
for j in range(steps_between_rollouts):
|
||||
batch = next(dataloader_with_offline_and_online_data)
|
||||
loss = policy(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
Note that the online training loop adopts most of the options from the offline loop unless specified
|
||||
otherwise.
|
||||
"""
|
||||
|
||||
steps: int = 0
|
||||
# How many episodes to collect at once when we reach the online rollout part of the training loop.
|
||||
rollout_n_episodes: int = 1
|
||||
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
|
||||
# the policy. Ideally you should set this to by an even divisor of rollout_n_episodes.
|
||||
rollout_batch_size: int = 1
|
||||
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
|
||||
steps_between_rollouts: int | None = None
|
||||
# The proportion of online samples (vs offline samples) to include in the online training batches.
|
||||
sampling_ratio: float = 0.5
|
||||
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
|
||||
env_seed: int | None = None
|
||||
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
|
||||
# FIFO.
|
||||
buffer_capacity: int | None = None
|
||||
# The minimum number of frames to have in the online buffer before commencing online training.
|
||||
# If buffer_seed_size > rollout_n_episodes, the rollout will be run multiple times until the
|
||||
# seed size condition is satisfied.
|
||||
buffer_seed_size: int = 0
|
||||
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
|
||||
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
|
||||
# + eval + environment rendering simultaneously.
|
||||
do_rollout_async: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.steps == 0:
|
||||
return
|
||||
|
||||
if self.steps_between_rollouts is None:
|
||||
raise ValueError(
|
||||
"'steps_between_rollouts' must be set to a positive integer, but it is currently None."
|
||||
)
|
||||
if self.env_seed is None:
|
||||
raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
|
||||
if self.buffer_capacity is None:
|
||||
raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
||||
# `dir` is the directory of an existing run with at least one checkpoint in it.
|
||||
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
||||
# regardless of what's provided with the training command at the time of resumption.
|
||||
resume: bool = False
|
||||
device: str | None = None # cuda | cpu | mp
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
seed: int | None = 1000
|
||||
# Number of workers for the dataloader.
|
||||
num_workers: int = 4
|
||||
batch_size: int = 8
|
||||
eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
save_checkpoint: bool = True
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: int = 20_000
|
||||
offline: OfflineConfig = field(default_factory=OfflineConfig)
|
||||
online: OnlineConfig = field(default_factory=OnlineConfig)
|
||||
use_policy_training_preset: bool = True
|
||||
optimizer: OptimizerConfig | None = None
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
self.checkpoint_path = None
|
||||
|
||||
def validate(self):
|
||||
if not self.device:
|
||||
logging.warning("No device specified, trying to infer device automatically")
|
||||
device = auto_select_torch_device()
|
||||
self.device = device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError("A config_path is expected when resuming a run.")
|
||||
policy_path = Path(config_path).parent
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.checkpoint_path = policy_path.parent
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{self.policy.type}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
|
||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||
raise FileExistsError(
|
||||
f"Output directory {self.output_dir} alreay exists and resume is {self.resume}. "
|
||||
f"Please change your output directory so that {self.output_dir} is not overwritten."
|
||||
)
|
||||
elif not self.output_dir:
|
||||
now = dt.datetime.now()
|
||||
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||
self.output_dir = Path("outputs/train") / train_dir
|
||||
|
||||
if self.online.steps > 0:
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
|
||||
if self.env is None:
|
||||
raise ValueError("An environment is required for online training")
|
||||
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
elif self.use_policy_training_preset and not self.resume:
|
||||
self.optimizer = self.policy.get_optimizer_preset()
|
||||
self.scheduler = self.policy.get_scheduler_preset()
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: Type["TrainPipelineConfig"],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**kwargs,
|
||||
) -> "TrainPipelineConfig":
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
if Path(model_id).is_dir():
|
||||
if TRAIN_CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
|
||||
else:
|
||||
print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
elif Path(model_id).is_file():
|
||||
config_file = model_id
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=TRAIN_CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
cfg = draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
return cfg
|
||||
28
lerobot/configs/types.py
Normal file
28
lerobot/configs/types.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Note: We subclass str so that serialization is straightforward
|
||||
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
class FeatureType(str, Enum):
|
||||
STATE = "STATE"
|
||||
VISUAL = "VISUAL"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
MIN_MAX = "MIN_MAX"
|
||||
MEAN_STD = "MEAN_STD"
|
||||
IDENTITY = "IDENTITY"
|
||||
|
||||
|
||||
class DictLike(Protocol):
|
||||
def __getitem__(self, key: Any) -> Any: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
shape: tuple
|
||||
@@ -16,28 +16,42 @@ import argparse
|
||||
import time
|
||||
|
||||
|
||||
def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
def get_motor_bus_cls(brand: str) -> tuple:
|
||||
if brand == "feetech":
|
||||
from lerobot.common.robot_devices.motors.feetech import MODEL_BAUDRATE_TABLE
|
||||
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
|
||||
from lerobot.common.robot_devices.motors.feetech import (
|
||||
SCS_SERIES_BAUDRATE_TABLE as SERIES_BAUDRATE_TABLE,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
SCS_SERIES_BAUDRATE_TABLE,
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus as MotorsBusClass
|
||||
|
||||
return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE
|
||||
|
||||
elif brand == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.dynamixel import MODEL_BAUDRATE_TABLE
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
X_SERIES_BAUDRATE_TABLE as SERIES_BAUDRATE_TABLE,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
X_SERIES_BAUDRATE_TABLE,
|
||||
DynamixelMotorsBus,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus as MotorsBusClass
|
||||
|
||||
return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Currently we do not support this motor brand: {brand}. We currently support feetech and dynamixel motors."
|
||||
)
|
||||
|
||||
|
||||
def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls(
|
||||
brand
|
||||
)
|
||||
|
||||
# Check if the provided model exists in the model_baud_rate_table
|
||||
if model not in MODEL_BAUDRATE_TABLE:
|
||||
if model not in model_baudrate_table:
|
||||
raise ValueError(
|
||||
f"Invalid model '{model}' for brand '{brand}'. Supported models: {list(MODEL_BAUDRATE_TABLE.keys())}"
|
||||
f"Invalid model '{model}' for brand '{brand}'. Supported models: {list(model_baudrate_table.keys())}"
|
||||
)
|
||||
|
||||
# Setup motor names, indices, and models
|
||||
@@ -45,8 +59,10 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument
|
||||
motor_model = model # Use the motor model passed via argument
|
||||
|
||||
config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})
|
||||
|
||||
# Initialize the MotorBus with the correct port and motor configurations
|
||||
motor_bus = MotorsBusClass(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})
|
||||
motor_bus = motor_bus_cls(config=config)
|
||||
|
||||
# Try to connect to the motor bus and handle any connection-specific errors
|
||||
try:
|
||||
@@ -59,7 +75,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
# Motor bus is connected, proceed with the rest of the operations
|
||||
try:
|
||||
print("Scanning all baudrates and motor indices")
|
||||
all_baudrates = set(SERIES_BAUDRATE_TABLE.values())
|
||||
all_baudrates = set(series_baudrate_table.values())
|
||||
motor_index = -1 # Set the motor index to an out-of-range value.
|
||||
|
||||
for baudrate in all_baudrates:
|
||||
@@ -76,6 +92,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
"Error: More than one motor ID detected. This script is designed to only handle one motor at a time. Please disconnect all but one motor."
|
||||
)
|
||||
motor_index = present_ids[0]
|
||||
break
|
||||
|
||||
if motor_index == -1:
|
||||
raise ValueError("No motors detected. Please ensure you have one motor connected.")
|
||||
@@ -88,7 +105,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
|
||||
if baudrate != baudrate_des:
|
||||
print(f"Setting its baudrate to {baudrate_des}")
|
||||
baudrate_idx = list(SERIES_BAUDRATE_TABLE.values()).index(baudrate_des)
|
||||
baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)
|
||||
|
||||
# The write can fail, so we allow retries
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
|
||||
@@ -102,7 +119,8 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
raise OSError("Failed to write baudrate.")
|
||||
|
||||
print(f"Setting its index to desired index {motor_idx_des}")
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
|
||||
if brand == "feetech":
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
|
||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
|
||||
|
||||
present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
|
||||
|
||||
@@ -8,30 +8,42 @@ Examples of usage:
|
||||
|
||||
- Recalibrate your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=calibrate
|
||||
```
|
||||
|
||||
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
|
||||
# Remove the cameras from the robot definition. They are not used in 'teleoperate' anyway.
|
||||
python lerobot/scripts/control_robot.py teleoperate --robot-overrides '~cameras'
|
||||
# Add the cameras from the robot definition to visualize them:
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--fps 30
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=teleoperate \
|
||||
--control.fps=30
|
||||
```
|
||||
|
||||
- Record one episode in order to test replay:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--repo-id $USER/koch_test \
|
||||
--num-episodes 1 \
|
||||
--run-compute-stats 0
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=$USER/koch_test \
|
||||
--control.num_episodes=1 \
|
||||
--control.push_to_hub=True
|
||||
```
|
||||
|
||||
- Visualize dataset:
|
||||
@@ -44,21 +56,25 @@ python lerobot/scripts/visualize_dataset.py \
|
||||
- Replay this test episode:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--fps 30 \
|
||||
--repo-id $USER/koch_test \
|
||||
--episode 0
|
||||
--robot.type=so100 \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=$USER/koch_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
|
||||
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--repo-id $USER/koch_pick_place_lego \
|
||||
--num-episodes 50 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10
|
||||
--robot.type=so100 \
|
||||
--control.type=record \
|
||||
--control.fps 30 \
|
||||
--control.repo_id=$USER/koch_pick_place_lego \
|
||||
--control.num_episodes=50 \
|
||||
--control.warmup_time_s=2 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=10
|
||||
```
|
||||
|
||||
**NOTE**: You can use your keyboard to control data recording flow.
|
||||
@@ -68,44 +84,55 @@ python lerobot/scripts/control_robot.py record \
|
||||
- Tap escape key 'esc' to stop the data recording.
|
||||
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
||||
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--resume 1`.
|
||||
If the dataset you want to extend is not on the hub, you also need to add `--local-files-only 1`.
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
|
||||
If the dataset you want to extend is not on the hub, you also need to add `--control.local_files_only=true`.
|
||||
|
||||
- Train on this dataset with the ACT policy:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
dataset_repo_id=$USER/koch_pick_place_lego \
|
||||
hydra.run.dir=outputs/train/act_koch_real
|
||||
--dataset.repo_id=${HF_USER}/koch_pick_place_lego \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_koch_pick_place_lego \
|
||||
--job_name=act_koch_pick_place_lego \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
- Run the pretrained policy on the robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--repo-id $USER/eval_act_koch_real \
|
||||
--num-episodes 10 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10
|
||||
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=$USER/eval_act_koch_pick_place_lego \
|
||||
--control.num_episodes=10 \
|
||||
--control.warmup_time_s=2 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_koch_pick_place_lego/checkpoints/080000/pretrained_model
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
ControlPipelineConfig,
|
||||
RecordControlConfig,
|
||||
ReplayControlConfig,
|
||||
TeleoperateControlConfig,
|
||||
)
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
has_method,
|
||||
init_keyboard_listener,
|
||||
init_policy,
|
||||
log_control_info,
|
||||
record_episode,
|
||||
reset_environment,
|
||||
@@ -114,10 +141,10 @@ from lerobot.common.robot_devices.control_utils import (
|
||||
stop_recording,
|
||||
warmup_record,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
|
||||
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int
|
||||
from lerobot.common.utils.utils import has_method, init_logging, log_say
|
||||
from lerobot.configs import parser
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
@@ -125,7 +152,7 @@ from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say,
|
||||
|
||||
|
||||
@safe_disconnect
|
||||
def calibrate(robot: Robot, arms: list[str] | None):
|
||||
def calibrate(robot: Robot, cfg: CalibrateControlConfig):
|
||||
# TODO(aliberts): move this code in robots' classes
|
||||
if robot.robot_type.startswith("stretch"):
|
||||
if not robot.is_connected:
|
||||
@@ -134,9 +161,7 @@ def calibrate(robot: Robot, arms: list[str] | None):
|
||||
robot.home()
|
||||
return
|
||||
|
||||
if arms is None:
|
||||
arms = robot.available_arms
|
||||
|
||||
arms = robot.available_arms if cfg.arms is None else cfg.arms
|
||||
unknown_arms = [arm_id for arm_id in arms if arm_id not in robot.available_arms]
|
||||
available_arms_str = " ".join(robot.available_arms)
|
||||
unknown_arms_str = " ".join(unknown_arms)
|
||||
@@ -171,91 +196,50 @@ def calibrate(robot: Robot, arms: list[str] | None):
|
||||
|
||||
|
||||
@safe_disconnect
|
||||
def teleoperate(
|
||||
robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False
|
||||
):
|
||||
def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
|
||||
control_loop(
|
||||
robot,
|
||||
control_time_s=teleop_time_s,
|
||||
fps=fps,
|
||||
control_time_s=cfg.teleop_time_s,
|
||||
fps=cfg.fps,
|
||||
teleoperate=True,
|
||||
display_cameras=display_cameras,
|
||||
display_cameras=cfg.display_cameras,
|
||||
)
|
||||
|
||||
|
||||
@safe_disconnect
|
||||
def record(
|
||||
robot: Robot,
|
||||
root: Path,
|
||||
repo_id: str,
|
||||
single_task: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
fps: int | None = None,
|
||||
warmup_time_s: int | float = 2,
|
||||
episode_time_s: int | float = 10,
|
||||
reset_time_s: int | float = 5,
|
||||
num_episodes: int = 50,
|
||||
video: bool = True,
|
||||
run_compute_stats: bool = True,
|
||||
push_to_hub: bool = True,
|
||||
tags: list[str] | None = None,
|
||||
num_image_writer_processes: int = 0,
|
||||
num_image_writer_threads_per_camera: int = 4,
|
||||
display_cameras: bool = True,
|
||||
play_sounds: bool = True,
|
||||
resume: bool = False,
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
local_files_only: bool = False,
|
||||
cfg: RecordControlConfig,
|
||||
) -> LeRobotDataset:
|
||||
# TODO(rcadene): Add option to record logs
|
||||
listener = None
|
||||
events = None
|
||||
policy = None
|
||||
device = None
|
||||
use_amp = None
|
||||
|
||||
if single_task:
|
||||
task = single_task
|
||||
else:
|
||||
raise NotImplementedError("Only single-task recording is supported for now")
|
||||
|
||||
# Load pretrained policy
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
|
||||
if fps is None:
|
||||
fps = policy_fps
|
||||
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
|
||||
elif fps != policy_fps:
|
||||
logging.warning(
|
||||
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
|
||||
)
|
||||
|
||||
if resume:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
root=root,
|
||||
local_files_only=local_files_only,
|
||||
cfg.repo_id,
|
||||
root=cfg.root,
|
||||
local_files_only=cfg.local_files_only,
|
||||
)
|
||||
dataset.start_image_writer(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
|
||||
if len(robot.cameras) > 0:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.num_image_writer_processes,
|
||||
num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
sanity_check_dataset_name(cfg.repo_id, cfg.policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id,
|
||||
fps,
|
||||
root=root,
|
||||
cfg.repo_id,
|
||||
cfg.fps,
|
||||
root=cfg.root,
|
||||
robot=robot,
|
||||
use_videos=video,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
use_videos=cfg.video,
|
||||
image_writer_processes=cfg.num_image_writer_processes,
|
||||
image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
@@ -266,33 +250,28 @@ def record(
|
||||
# 2. give times to the robot devices to connect and start synchronizing,
|
||||
# 3. place the cameras windows on screen
|
||||
enable_teleoperation = policy is None
|
||||
log_say("Warmup record", play_sounds)
|
||||
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps)
|
||||
log_say("Warmup record", cfg.play_sounds)
|
||||
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
|
||||
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
recorded_episodes = 0
|
||||
while True:
|
||||
if recorded_episodes >= num_episodes:
|
||||
if recorded_episodes >= cfg.num_episodes:
|
||||
break
|
||||
|
||||
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
|
||||
# input() messes with them.
|
||||
# if multi_task:
|
||||
# task = input("Enter your task description: ")
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
robot=robot,
|
||||
events=events,
|
||||
episode_time_s=episode_time_s,
|
||||
display_cameras=display_cameras,
|
||||
episode_time_s=cfg.episode_time_s,
|
||||
display_cameras=cfg.display_cameras,
|
||||
policy=policy,
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
device=cfg.device,
|
||||
use_amp=cfg.use_amp,
|
||||
fps=cfg.fps,
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
@@ -300,59 +279,56 @@ def record(
|
||||
# TODO(rcadene): add an option to enable teleoperation during reset
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < num_episodes - 1) or events["rerecord_episode"]
|
||||
(recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
log_say("Reset the environment", cfg.play_sounds)
|
||||
reset_environment(robot, events, cfg.reset_time_s)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
log_say("Re-record episode", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode(task)
|
||||
dataset.save_episode(cfg.single_task)
|
||||
recorded_episodes += 1
|
||||
|
||||
if events["stop_recording"]:
|
||||
break
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, cfg.display_cameras)
|
||||
|
||||
if run_compute_stats:
|
||||
if cfg.run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
|
||||
dataset.consolidate(run_compute_stats)
|
||||
dataset.consolidate(cfg.run_compute_stats)
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub(tags=tags)
|
||||
if cfg.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
||||
|
||||
log_say("Exiting", play_sounds)
|
||||
log_say("Exiting", cfg.play_sounds)
|
||||
return dataset
|
||||
|
||||
|
||||
@safe_disconnect
|
||||
def replay(
|
||||
robot: Robot,
|
||||
root: Path,
|
||||
repo_id: str,
|
||||
episode: int,
|
||||
fps: int | None = None,
|
||||
play_sounds: bool = True,
|
||||
local_files_only: bool = False,
|
||||
cfg: ReplayControlConfig,
|
||||
):
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.repo_id, root=cfg.root, episodes=[cfg.episode], local_files_only=cfg.local_files_only
|
||||
)
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", play_sounds, blocking=True)
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
@@ -360,216 +336,33 @@ def replay(
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
busy_wait(1 / cfg.fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
log_control_info(robot, dt_s, fps=cfg.fps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="mode", required=True)
|
||||
|
||||
# Set common options for all the subparsers
|
||||
base_parser = argparse.ArgumentParser(add_help=False)
|
||||
base_parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
base_parser.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
parser_calib = subparsers.add_parser("calibrate", parents=[base_parser])
|
||||
parser_calib.add_argument(
|
||||
"--arms",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="List of arms to calibrate (e.g. `--arms left_follower right_follower left_leader`)",
|
||||
)
|
||||
|
||||
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||
parser_teleop.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_teleop.add_argument(
|
||||
"--display-cameras",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Display all cameras on screen (set to 1 to display or 0).",
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
task_args = parser_record.add_mutually_exclusive_group(required=True)
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--single-task",
|
||||
type=str,
|
||||
help="A short but accurate description of the task performed during the recording.",
|
||||
)
|
||||
# TODO(aliberts): add multi-task support
|
||||
# task_args.add_argument(
|
||||
# "--multi-task",
|
||||
# type=int,
|
||||
# help="You will need to enter the task performed at the start of each episode.",
|
||||
# )
|
||||
parser_record.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot/test",
|
||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--local-files-only",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--warmup-time-s",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--episode-time-s",
|
||||
type=int,
|
||||
default=60,
|
||||
help="Number of seconds for data recording for each episode.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--reset-time-s",
|
||||
type=int,
|
||||
default=60,
|
||||
help="Number of seconds for resetting the environment after each episode.",
|
||||
)
|
||||
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
|
||||
parser_record.add_argument(
|
||||
"--run-compute-stats",
|
||||
type=int,
|
||||
default=1,
|
||||
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--push-to-hub",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Upload dataset to Hugging Face hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--tags",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-processes",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
|
||||
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
|
||||
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
|
||||
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-threads-per-camera",
|
||||
type=int,
|
||||
default=4,
|
||||
help=(
|
||||
"Number of threads writing the frames as png images on disk, per camera. "
|
||||
"Too many threads might cause unstable teleoperation fps due to main thread being blocked. "
|
||||
"Not enough threads might cause low camera fps."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--resume",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Resume recording on an existing dataset.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--policy-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot/test",
|
||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--local-files-only",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||
)
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@parser.wrap()
|
||||
def control_robot(cfg: ControlPipelineConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
control_mode = args.mode
|
||||
robot_path = args.robot_path
|
||||
robot_overrides = args.robot_overrides
|
||||
kwargs = vars(args)
|
||||
del kwargs["mode"]
|
||||
del kwargs["robot_path"]
|
||||
del kwargs["robot_overrides"]
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
|
||||
robot_cfg = init_hydra_config(robot_path, robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
|
||||
if control_mode == "calibrate":
|
||||
calibrate(robot, **kwargs)
|
||||
|
||||
elif control_mode == "teleoperate":
|
||||
teleoperate(robot, **kwargs)
|
||||
|
||||
elif control_mode == "record":
|
||||
record(robot, **kwargs)
|
||||
|
||||
elif control_mode == "replay":
|
||||
replay(robot, **kwargs)
|
||||
if isinstance(cfg.control, CalibrateControlConfig):
|
||||
calibrate(robot, cfg.control)
|
||||
elif isinstance(cfg.control, TeleoperateControlConfig):
|
||||
teleoperate(robot, cfg.control)
|
||||
elif isinstance(cfg.control, RecordControlConfig):
|
||||
record(robot, cfg.control)
|
||||
elif isinstance(cfg.control, ReplayControlConfig):
|
||||
replay(robot, cfg.control)
|
||||
|
||||
if robot.is_connected:
|
||||
# Disconnect manually to avoid a "Core dump" during process
|
||||
# termination due to camera threads not properly exiting.
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
control_robot()
|
||||
|
||||
@@ -90,11 +90,12 @@ from lerobot.common.robot_devices.control_utils import (
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
stop_recording,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot, make_robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say
|
||||
|
||||
raise NotImplementedError("This script is currently deactivated")
|
||||
|
||||
DEFAULT_FEATURES = {
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
@@ -227,7 +228,7 @@ def record(
|
||||
shape = env.observation_space[key].shape
|
||||
if not key.startswith("observation.image."):
|
||||
key = "observation.image." + key
|
||||
features[key] = {"dtype": "video", "names": ["channel", "height", "width"], "shape": shape}
|
||||
features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape}
|
||||
|
||||
for key, obs_key in state_keys_dict.items():
|
||||
features[key] = {
|
||||
@@ -504,7 +505,7 @@ if __name__ == "__main__":
|
||||
|
||||
# make gym env
|
||||
env_cfg = init_hydra_config(env_config_path)
|
||||
importlib.import_module(f"gym_{env_cfg.env.name}")
|
||||
importlib.import_module(f"gym_{env_cfg.env.type}")
|
||||
|
||||
def env_constructor():
|
||||
return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym)
|
||||
@@ -515,6 +516,7 @@ if __name__ == "__main__":
|
||||
if control_mode in ["teleoperate", "record"]:
|
||||
# make robot
|
||||
robot_overrides = ["~cameras", "~follower_arms"]
|
||||
# TODO(rcadene): remove
|
||||
robot_cfg = init_hydra_config(robot_path, robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
@@ -21,67 +21,69 @@ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/di
|
||||
for 10 episodes.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval.py -p lerobot/diffusion_pusht eval.n_episodes=10
|
||||
python lerobot/scripts/eval.py \
|
||||
--policy.path=lerobot/diffusion_pusht \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
--eval.n_episodes=10 \
|
||||
--use_amp=false \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval.py \
|
||||
-p outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
|
||||
eval.n_episodes=10
|
||||
--policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
--eval.n_episodes=10 \
|
||||
--use_amp=false \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
Note that in both examples, the repo/folder should contain at least `config.json`, `config.yaml` and
|
||||
`model.safetensors`.
|
||||
Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
|
||||
|
||||
Note the formatting for providing the number of episodes. Generally, you may provide any number of arguments
|
||||
with `qualified.parameter.name=value`. In this case, the parameter eval.n_episodes appears as `n_episodes`
|
||||
nested under `eval` in the `config.yaml` found at
|
||||
https://huggingface.co/lerobot/diffusion_pusht/tree/main.
|
||||
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from datetime import datetime as dt
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.io_utils import write_video
|
||||
from lerobot.common.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: Policy,
|
||||
policy: PreTrainedPolicy,
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
@@ -123,6 +125,9 @@ def rollout(
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
|
||||
if hasattr(policy, "use_ema_modules"):
|
||||
policy.use_ema_modules()
|
||||
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
@@ -203,12 +208,15 @@ def rollout(
|
||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
ret["observation"] = stacked_observations
|
||||
|
||||
if hasattr(policy, "use_original_modules"):
|
||||
policy.use_original_modules()
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def eval_policy(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: torch.nn.Module,
|
||||
policy: PreTrainedPolicy,
|
||||
n_episodes: int,
|
||||
max_episodes_rendered: int = 0,
|
||||
videos_dir: Path | None = None,
|
||||
@@ -232,7 +240,11 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0 and not videos_dir:
|
||||
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
|
||||
|
||||
assert isinstance(policy, Policy)
|
||||
if not isinstance(policy, PreTrainedPolicy):
|
||||
raise ValueError(
|
||||
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
policy.eval()
|
||||
|
||||
@@ -442,66 +454,43 @@ def _compile_episode_data(
|
||||
return data_dict
|
||||
|
||||
|
||||
def main(
|
||||
pretrained_policy_path: Path | None = None,
|
||||
hydra_cfg_path: str | None = None,
|
||||
out_dir: str | None = None,
|
||||
config_overrides: list[str] | None = None,
|
||||
):
|
||||
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
|
||||
if pretrained_policy_path is not None:
|
||||
hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
|
||||
else:
|
||||
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
|
||||
|
||||
if hydra_cfg.eval.batch_size > hydra_cfg.eval.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
f"({hydra_cfg.eval.batch_size} > {hydra_cfg.eval.n_episodes}). As a result, {hydra_cfg.eval.batch_size} "
|
||||
f"eval environments will be instantiated, but only {hydra_cfg.eval.n_episodes} will be used. "
|
||||
"This might significantly slow down evaluation. To fix this, you should update your command "
|
||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={hydra_cfg.eval.batch_size}`), "
|
||||
f"or lower the batch size (e.g. `eval.batch_size={hydra_cfg.eval.n_episodes}`)."
|
||||
)
|
||||
|
||||
if out_dir is None:
|
||||
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
||||
@parser.wrap()
|
||||
def eval(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
log_output_dir(out_dir)
|
||||
log_output_dir(cfg.output_dir)
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(hydra_cfg)
|
||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Making policy.")
|
||||
if hydra_cfg_path is None:
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
|
||||
else:
|
||||
# Note: We need the dataset stats to pass to the policy's normalization modules.
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
info = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
hydra_cfg.eval.n_episodes,
|
||||
cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(out_dir) / "videos",
|
||||
start_seed=hydra_cfg.seed,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
print(info["aggregated"])
|
||||
|
||||
# Save info
|
||||
with open(Path(out_dir) / "eval_info.json", "w") as f:
|
||||
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
env.close()
|
||||
@@ -509,76 +498,6 @@ def main(
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||
try:
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||
)
|
||||
else:
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||
)
|
||||
|
||||
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||
pretrained_policy_path = Path(pretrained_policy_name_or_path)
|
||||
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||
raise ValueError(
|
||||
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
return pretrained_policy_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
help=(
|
||||
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
"Where to save the evaluation outputs. If not provided, outputs are saved in "
|
||||
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.pretrained_policy_name_or_path is None:
|
||||
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
||||
else:
|
||||
pretrained_policy_path = get_pretrained_policy_path(
|
||||
args.pretrained_policy_name_or_path, revision=args.revision
|
||||
)
|
||||
|
||||
main(
|
||||
pretrained_policy_path=pretrained_policy_path,
|
||||
out_dir=args.out_dir,
|
||||
config_overrides=args.overrides,
|
||||
)
|
||||
eval()
|
||||
|
||||
71
lerobot/scripts/push_pretrained.py
Normal file
71
lerobot/scripts/push_pretrained.py
Normal file
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Once you have trained a policy with our training script (lerobot/scripts/train.py), use this script to push it
|
||||
to the hub.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/push_pretrained.py \
|
||||
--pretrained_path=outputs/train/act_aloha_sim_transfer_cube_human/checkpoints/last/pretrained_model \
|
||||
--repo_id=lerobot/act_aloha_sim_transfer_cube_human
|
||||
```
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
|
||||
@dataclass
|
||||
class PushPreTrainedConfig:
|
||||
pretrained_path: Path
|
||||
repo_id: str
|
||||
branch: str | None = None
|
||||
private: bool = False
|
||||
exist_ok: bool = False
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def main(cfg: PushPreTrainedConfig):
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(
|
||||
repo_id=cfg.repo_id,
|
||||
private=cfg.private,
|
||||
repo_type="model",
|
||||
exist_ok=cfg.exist_ok,
|
||||
)
|
||||
if cfg.branch:
|
||||
hub_api.create_branch(
|
||||
repo_id=cfg.repo_id,
|
||||
branch=cfg.branch,
|
||||
repo_type="model",
|
||||
exist_ok=cfg.exist_ok,
|
||||
)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=cfg.repo_id,
|
||||
folder_path=cfg.pretrained_path,
|
||||
repo_type="model",
|
||||
revision=cfg.branch,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -18,92 +18,37 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from threading import Lock
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.amp import GradScaler
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.optim.factory import load_training_state, make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_dtype,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
has_method,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(cfg, policy):
|
||||
if cfg.policy.name == "act":
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": cfg.training.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
|
||||
)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "diffusion":
|
||||
optimizer = torch.optim.Adam(
|
||||
policy.diffusion.parameters(),
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
cfg.training.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
num_training_steps=cfg.training.offline_steps,
|
||||
)
|
||||
elif policy.name == "tdmpc":
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||
|
||||
optimizer = VQBeTOptimizer(policy, cfg)
|
||||
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def update_policy(
|
||||
policy,
|
||||
batch,
|
||||
@@ -142,10 +87,14 @@ def update_policy(
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if hasattr(policy, "update_ema_modules"):
|
||||
policy.update_ema_modules()
|
||||
|
||||
# Step through pytorch scheduler at every batch instead of epoch
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if isinstance(policy, PolicyWithUpdate):
|
||||
if has_method(policy, "update"):
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
|
||||
@@ -161,7 +110,9 @@ def update_policy(
|
||||
return info
|
||||
|
||||
|
||||
def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
|
||||
def log_train_info(
|
||||
logger: Logger, info: dict, step: int, cfg: TrainPipelineConfig, dataset: LeRobotDataset, is_online: bool
|
||||
):
|
||||
loss = info["loss"]
|
||||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
@@ -170,7 +121,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
|
||||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.training.batch_size
|
||||
num_samples = (step + 1) * cfg.batch_size
|
||||
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_frames
|
||||
@@ -207,7 +158,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
|
||||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.training.batch_size
|
||||
num_samples = (step + 1) * cfg.batch_size
|
||||
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_frames
|
||||
@@ -234,74 +185,17 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
|
||||
logger.log_dict(info, step, mode="eval")
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
|
||||
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
|
||||
|
||||
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
|
||||
# to check for any differences between the provided config and the checkpoint's config.
|
||||
if cfg.resume:
|
||||
if not Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
"You have set resume=True, but there is no model checkpoint in "
|
||||
f"{Logger.get_last_checkpoint_dir(out_dir)}"
|
||||
)
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"You have set resume=True, indicating that you wish to resume a run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
# Get the configuration file from the last checkpoint.
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
# Check for differences between the checkpoint configuration and provided configuration.
|
||||
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
|
||||
resolve_delta_timestamps(cfg)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
# Ignore the `resume` and parameters.
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
# Log a warning about differences between the checkpoint configuration and the provided
|
||||
# configuration.
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
"At least one difference was detected between the checkpoint configuration and "
|
||||
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
|
||||
"takes precedence.",
|
||||
)
|
||||
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
|
||||
cfg = checkpoint_cfg
|
||||
cfg.resume = True
|
||||
elif Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
|
||||
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
|
||||
)
|
||||
|
||||
if cfg.eval.batch_size > cfg.eval.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
|
||||
f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
|
||||
"This might significantly slow down evaluation. To fix this, you should update your command "
|
||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
|
||||
f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
|
||||
)
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# log metrics to terminal and wandb
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
logger = Logger(cfg)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
if cfg.seed is not None:
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
@@ -309,65 +203,59 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_dataset")
|
||||
logging.info("Creating dataset")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
if isinstance(offline_dataset, MultiLeRobotDataset):
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
|
||||
)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
eval_env = None
|
||||
if cfg.training.eval_freq > 0:
|
||||
logging.info("make_env")
|
||||
eval_env = make_env(cfg)
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
|
||||
logging.info("make_policy")
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
ds_meta=offline_dataset.meta,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
# Create optimizer and scheduler
|
||||
# Temporary hack to move optimizer out of policy
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(enabled=cfg.use_amp)
|
||||
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
step = logger.load_last_training_state(optimizer, lr_scheduler)
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
log_output_dir(cfg.output_dir)
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.offline.steps=} ({format_big_number(cfg.offline.steps)})")
|
||||
logging.info(f"{cfg.online.steps=}")
|
||||
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
|
||||
logging.info(f"{offline_dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# Note: this helper will be used in offline and online training loops.
|
||||
def evaluate_and_checkpoint_if_needed(step, is_online):
|
||||
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
|
||||
_num_digits = max(6, len(str(cfg.offline.steps + cfg.online.steps)))
|
||||
step_identifier = f"{step:0{_num_digits}d}"
|
||||
|
||||
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
|
||||
if cfg.env is not None and cfg.eval_freq > 0 and step % cfg.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
assert eval_env is not None
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_identifier}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
@@ -376,28 +264,27 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if cfg.training.save_checkpoint and (
|
||||
step % cfg.training.save_freq == 0
|
||||
or step == cfg.training.offline_steps + cfg.training.online_steps
|
||||
if cfg.save_checkpoint and (
|
||||
step % cfg.save_freq == 0 or step == cfg.offline.steps + cfg.online.steps
|
||||
):
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
logger.save_checkpoint(
|
||||
step,
|
||||
step_identifier,
|
||||
policy,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
identifier=step_identifier,
|
||||
)
|
||||
logging.info("Resume training")
|
||||
|
||||
# create dataloader for offline training
|
||||
if cfg.training.get("drop_n_last_frames"):
|
||||
if getattr(cfg.policy, "drop_n_last_frames", None):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
offline_dataset.episode_data_index,
|
||||
drop_n_last_frames=cfg.training.drop_n_last_frames,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
@@ -405,8 +292,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
sampler = None
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
offline_dataset,
|
||||
num_workers=cfg.training.num_workers,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
@@ -415,8 +302,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
|
||||
if hasattr(policy, "init_ema_modules"):
|
||||
policy.init_ema_modules()
|
||||
|
||||
offline_step = 0
|
||||
for _ in range(step, cfg.training.offline_steps):
|
||||
for _ in range(step, cfg.offline.steps):
|
||||
if offline_step == 0:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
@@ -425,13 +316,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.training.grad_clip_norm,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
@@ -439,7 +331,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
if step % cfg.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
@@ -449,7 +341,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
step += 1
|
||||
offline_step += 1 # noqa: SIM113
|
||||
|
||||
if cfg.training.online_steps == 0:
|
||||
if cfg.online.steps == 0:
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
@@ -458,8 +350,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# Online training.
|
||||
|
||||
# Create an env dedicated to online episodes collection from policy rollout.
|
||||
online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
|
||||
resolve_delta_timestamps(cfg)
|
||||
online_env = make_env(cfg.env, n_envs=cfg.online.rollout_batch_size)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, offline_dataset.meta)
|
||||
online_buffer_path = logger.log_dir / "online_buffer"
|
||||
if cfg.resume and not online_buffer_path.exists():
|
||||
# If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
|
||||
@@ -473,31 +365,41 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
online_dataset = OnlineBuffer(
|
||||
online_buffer_path,
|
||||
data_spec={
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()},
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
|
||||
**{
|
||||
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
|
||||
for key, ft in policy.config.input_features.items()
|
||||
},
|
||||
**{
|
||||
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
|
||||
for key, ft in policy.config.output_features.items()
|
||||
},
|
||||
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
|
||||
"next.done": {"shape": (), "dtype": np.dtype("?")},
|
||||
"task_index": {"shape": (), "dtype": np.dtype("int64")},
|
||||
# FIXME: 'task' is a string
|
||||
# "task": {"shape": (), "dtype": np.dtype("?")},
|
||||
# FIXME: 'next.success' is expected by pusht env but not xarm
|
||||
"next.success": {"shape": (), "dtype": np.dtype("?")},
|
||||
},
|
||||
buffer_capacity=cfg.training.online_buffer_capacity,
|
||||
buffer_capacity=cfg.online.buffer_capacity,
|
||||
fps=online_env.unwrapped.metadata["render_fps"],
|
||||
delta_timestamps=cfg.training.delta_timestamps,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
|
||||
# makes it possible to do online rollouts in parallel with training updates).
|
||||
online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy
|
||||
online_rollout_policy = deepcopy(policy) if cfg.online.do_rollout_async else policy
|
||||
|
||||
# Create dataloader for online training.
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
sampler_weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
|
||||
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.online.sampling_ratio,
|
||||
)
|
||||
sampler = torch.utils.data.WeightedRandomSampler(
|
||||
sampler_weights,
|
||||
@@ -506,20 +408,22 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.training.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
# Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
|
||||
# these are still used but effectively do nothing.
|
||||
lock = Lock()
|
||||
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
|
||||
# parallelization of rollouts is handled within the job.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
if cfg.online.do_rollout_async:
|
||||
# Lock and thread pool executor for asynchronous online rollouts.
|
||||
lock = Lock()
|
||||
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
|
||||
# parallelization of rollouts is handled within the job.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
else:
|
||||
lock = None
|
||||
|
||||
online_step = 0
|
||||
online_rollout_s = 0 # time take to do online rollout
|
||||
@@ -527,10 +431,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
|
||||
# online rollout option.
|
||||
await_update_online_buffer_s = 0
|
||||
rollout_start_seed = cfg.training.online_env_seed
|
||||
rollout_start_seed = cfg.online.env_seed
|
||||
|
||||
while True:
|
||||
if online_step == cfg.training.online_steps:
|
||||
if online_step == cfg.online.steps:
|
||||
break
|
||||
|
||||
if online_step == 0:
|
||||
@@ -538,25 +442,34 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
def sample_trajectory_and_update_buffer():
|
||||
nonlocal rollout_start_seed
|
||||
with lock:
|
||||
|
||||
with lock if lock is not None else nullcontext():
|
||||
online_rollout_policy.load_state_dict(policy.state_dict())
|
||||
|
||||
online_rollout_policy.eval()
|
||||
start_rollout_time = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
online_env,
|
||||
online_rollout_policy,
|
||||
n_episodes=cfg.training.online_rollout_n_episodes,
|
||||
max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes),
|
||||
n_episodes=cfg.online.rollout_n_episodes,
|
||||
max_episodes_rendered=min(10, cfg.online.rollout_n_episodes),
|
||||
videos_dir=logger.log_dir / "online_rollout_videos",
|
||||
return_episode_data=True,
|
||||
start_seed=(
|
||||
rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000
|
||||
),
|
||||
start_seed=(rollout_start_seed := (rollout_start_seed + cfg.batch_size) % 1000000),
|
||||
)
|
||||
online_rollout_s = time.perf_counter() - start_rollout_time
|
||||
|
||||
with lock:
|
||||
if len(offline_dataset.meta.tasks) > 1:
|
||||
raise NotImplementedError("Add support for multi task.")
|
||||
|
||||
# TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
|
||||
total_num_frames = eval_info["episodes"]["index"].shape[0]
|
||||
eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
|
||||
eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames
|
||||
|
||||
with lock if lock is not None else nullcontext():
|
||||
start_update_buffer_time = time.perf_counter()
|
||||
online_dataset.add_data(eval_info["episodes"])
|
||||
|
||||
@@ -566,12 +479,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# Update the sampling weights.
|
||||
sampler.weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
|
||||
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.online.sampling_ratio,
|
||||
)
|
||||
sampler.num_frames = len(concat_dataset)
|
||||
|
||||
@@ -579,36 +492,36 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
return online_rollout_s, update_online_buffer_s
|
||||
|
||||
future = executor.submit(sample_trajectory_and_update_buffer)
|
||||
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
|
||||
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
|
||||
if (
|
||||
not cfg.training.do_online_rollout_async
|
||||
or len(online_dataset) <= cfg.training.online_buffer_seed_size
|
||||
):
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
if lock is None:
|
||||
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
|
||||
else:
|
||||
future = executor.submit(sample_trajectory_and_update_buffer)
|
||||
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
|
||||
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
|
||||
if len(online_dataset) <= cfg.online.buffer_seed_size:
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
|
||||
if len(online_dataset) <= cfg.training.online_buffer_seed_size:
|
||||
logging.info(
|
||||
f"Seeding online buffer: {len(online_dataset)}/{cfg.training.online_buffer_seed_size}"
|
||||
)
|
||||
if len(online_dataset) <= cfg.online.buffer_seed_size:
|
||||
logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
|
||||
continue
|
||||
|
||||
policy.train()
|
||||
for _ in range(cfg.training.online_steps_between_rollouts):
|
||||
with lock:
|
||||
for _ in range(cfg.online.steps_between_rollouts):
|
||||
with lock if lock is not None else nullcontext():
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
dtype = get_safe_dtype(batch[key].dtype, device)
|
||||
batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True)
|
||||
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.training.grad_clip_norm,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
@@ -619,10 +532,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
train_info["online_rollout_s"] = online_rollout_s
|
||||
train_info["update_online_buffer_s"] = update_online_buffer_s
|
||||
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
|
||||
with lock:
|
||||
with lock if lock is not None else nullcontext():
|
||||
train_info["online_buffer_size"] = len(online_dataset)
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
if step % cfg.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
@@ -634,12 +547,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
# If we're doing async rollouts, we should now wait until we've completed them before proceeding
|
||||
# to do the next batch of rollouts.
|
||||
if future.running():
|
||||
if cfg.online.do_rollout_async and future.running():
|
||||
start = time.perf_counter()
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
await_update_online_buffer_s = time.perf_counter() - start
|
||||
|
||||
if online_step >= cfg.training.online_steps:
|
||||
if online_step >= cfg.online.steps:
|
||||
break
|
||||
|
||||
if eval_env:
|
||||
@@ -647,23 +560,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
|
||||
from hydra import compose, initialize
|
||||
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
initialize(config_path=config_path)
|
||||
cfg = compose(config_name=config_name)
|
||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
init_logging()
|
||||
train()
|
||||
|
||||
@@ -15,14 +15,14 @@
|
||||
# limitations under the License.
|
||||
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
||||
|
||||
Note: The last frame of the episode doesnt always correspond to a final state.
|
||||
Note: The last frame of the episode doesn't always correspond to a final state.
|
||||
That's because our datasets are composed of transition from state to state up to
|
||||
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
||||
However, there might not be a transition from a final state to another state.
|
||||
|
||||
Note: This script aims to visualize the data used to train the neural networks.
|
||||
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
||||
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
|
||||
lossy compression artifacts since these images have been decoded from compressed mp4 videos to
|
||||
save disk space. The compression factor applied has been tuned to not affect success rate.
|
||||
|
||||
Examples:
|
||||
@@ -199,7 +199,7 @@ def main():
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
|
||||
help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episode-index",
|
||||
|
||||
@@ -18,142 +18,102 @@
|
||||
This script will generate examples of transformed images as they are output by LeRobot dataset.
|
||||
Additionally, each individual transform can be visualized separately as well as examples of combined transforms
|
||||
|
||||
|
||||
--- Usage Examples ---
|
||||
|
||||
Increase hue jitter
|
||||
```
|
||||
Example:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.hue.min_max="[-0.25,0.25]"
|
||||
--repo_id=lerobot/pusht \
|
||||
--episodes='[0]' \
|
||||
--image_transforms.enable=True
|
||||
```
|
||||
|
||||
Increase brightness & brightness weight
|
||||
```
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.brightness.weight=10.0 \
|
||||
training.image_transforms.brightness.min_max="[1.0,2.0]"
|
||||
```
|
||||
|
||||
Blur images and disable saturation & hue
|
||||
```
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.sharpness.weight=10.0 \
|
||||
training.image_transforms.sharpness.min_max="[0.0,1.0]" \
|
||||
training.image_transforms.saturation.weight=0.0 \
|
||||
training.image_transforms.hue.weight=0.0
|
||||
```
|
||||
|
||||
Use all transforms with random order
|
||||
```
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.max_num_transforms=5 \
|
||||
training.image_transforms.random_order=true
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from dataclasses import replace
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import draccus
|
||||
from torchvision.transforms import ToPILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
from lerobot.common.datasets.transforms import (
|
||||
ImageTransforms,
|
||||
ImageTransformsConfig,
|
||||
make_transform_from_config,
|
||||
)
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
|
||||
OUTPUT_DIR = Path("outputs/image_transforms")
|
||||
to_pil = ToPILImage()
|
||||
|
||||
|
||||
def save_config_all_transforms(cfg, original_frame, output_dir, n_examples):
|
||||
tf = get_image_transforms(
|
||||
brightness_weight=cfg.brightness.weight,
|
||||
brightness_min_max=cfg.brightness.min_max,
|
||||
contrast_weight=cfg.contrast.weight,
|
||||
contrast_min_max=cfg.contrast.min_max,
|
||||
saturation_weight=cfg.saturation.weight,
|
||||
saturation_min_max=cfg.saturation.min_max,
|
||||
hue_weight=cfg.hue.weight,
|
||||
hue_min_max=cfg.hue.min_max,
|
||||
sharpness_weight=cfg.sharpness.weight,
|
||||
sharpness_min_max=cfg.sharpness.min_max,
|
||||
max_num_transforms=cfg.max_num_transforms,
|
||||
random_order=cfg.random_order,
|
||||
)
|
||||
|
||||
def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
|
||||
output_dir_all = output_dir / "all"
|
||||
output_dir_all.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tfs = ImageTransforms(cfg)
|
||||
for i in range(1, n_examples + 1):
|
||||
transformed_frame = tf(original_frame)
|
||||
transformed_frame = tfs(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
|
||||
|
||||
print("Combined transforms examples saved to:")
|
||||
print(f" {output_dir_all}")
|
||||
|
||||
|
||||
def save_config_single_transforms(cfg, original_frame, output_dir, n_examples):
|
||||
transforms = [
|
||||
"brightness",
|
||||
"contrast",
|
||||
"saturation",
|
||||
"hue",
|
||||
"sharpness",
|
||||
]
|
||||
def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
|
||||
if not cfg.enable:
|
||||
logging.warning(
|
||||
"No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
|
||||
)
|
||||
return
|
||||
|
||||
print("Individual transforms examples saved to:")
|
||||
for transform in transforms:
|
||||
# Apply one transformation with random value in min_max range
|
||||
kwargs = {
|
||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||
f"{transform}_min_max": cfg[f"{transform}"].min_max,
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
output_dir_single = output_dir / f"{transform}"
|
||||
for tf_name, tf_cfg in cfg.tfs.items():
|
||||
# Apply a few transformation with random value in min_max range
|
||||
output_dir_single = output_dir / tf_name
|
||||
output_dir_single.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
for i in range(1, n_examples + 1):
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
|
||||
|
||||
# Apply min transformation
|
||||
min_value, max_value = cfg[f"{transform}"].min_max
|
||||
kwargs = {
|
||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||
f"{transform}_min_max": (min_value, min_value),
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_single / "min.png", quality=100)
|
||||
# Apply min, max, average transformations
|
||||
tf_cfg_kwgs_min = deepcopy(tf_cfg.kwargs)
|
||||
tf_cfg_kwgs_max = deepcopy(tf_cfg.kwargs)
|
||||
tf_cfg_kwgs_avg = deepcopy(tf_cfg.kwargs)
|
||||
|
||||
# Apply max transformation
|
||||
kwargs = {
|
||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||
f"{transform}_min_max": (max_value, max_value),
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_single / "max.png", quality=100)
|
||||
for key, (min_, max_) in tf_cfg.kwargs.items():
|
||||
avg = (min_ + max_) / 2
|
||||
tf_cfg_kwgs_min[key] = [min_, min_]
|
||||
tf_cfg_kwgs_max[key] = [max_, max_]
|
||||
tf_cfg_kwgs_avg[key] = [avg, avg]
|
||||
|
||||
# Apply mean transformation
|
||||
mean_value = (min_value + max_value) / 2
|
||||
kwargs = {
|
||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||
f"{transform}_min_max": (mean_value, mean_value),
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_single / "mean.png", quality=100)
|
||||
tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}))
|
||||
tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}))
|
||||
tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}))
|
||||
|
||||
tf_frame_min = tf_min(original_frame)
|
||||
tf_frame_max = tf_max(original_frame)
|
||||
tf_frame_avg = tf_avg(original_frame)
|
||||
|
||||
to_pil(tf_frame_min).save(output_dir_single / "min.png", quality=100)
|
||||
to_pil(tf_frame_max).save(output_dir_single / "max.png", quality=100)
|
||||
to_pil(tf_frame_avg).save(output_dir_single / "mean.png", quality=100)
|
||||
|
||||
print(f" {output_dir_single}")
|
||||
|
||||
|
||||
def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
|
||||
dataset = LeRobotDataset(cfg.dataset_repo_id)
|
||||
@draccus.wrap()
|
||||
def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.repo_id,
|
||||
episodes=cfg.episodes,
|
||||
local_files_only=cfg.local_files_only,
|
||||
video_backend=cfg.video_backend,
|
||||
)
|
||||
|
||||
output_dir = output_dir / cfg.dataset_repo_id.split("/")[-1]
|
||||
output_dir = output_dir / cfg.repo_id.split("/")[-1]
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get 1st frame from 1st camera of 1st episode
|
||||
@@ -162,14 +122,9 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
|
||||
print("\nOriginal frame saved to:")
|
||||
print(f" {output_dir / 'original_frame.png'}.")
|
||||
|
||||
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
|
||||
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def visualize_transforms_cli(cfg):
|
||||
visualize_transforms(cfg, output_dir=OUTPUT_DIR)
|
||||
save_all_transforms(cfg.image_transforms, original_frame, output_dir, n_examples)
|
||||
save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
visualize_transforms_cli()
|
||||
visualize_image_transforms()
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
|
||||
</head>
|
||||
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
|
||||
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
|
||||
inputValue: '',
|
||||
navigateToDataset() {
|
||||
const trimmedValue = this.inputValue.trim();
|
||||
@@ -40,14 +40,14 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex w-full max-w-lg px-4 mb-4">
|
||||
<input
|
||||
type="text"
|
||||
<input
|
||||
type="text"
|
||||
x-model="inputValue"
|
||||
@keyup.enter="navigateToDataset"
|
||||
placeholder="enter dataset id (ex: lerobot/droid_100)"
|
||||
class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
<button
|
||||
<button
|
||||
@click="navigateToDataset"
|
||||
class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
@@ -65,4 +65,4 @@
|
||||
</details>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
</html>
|
||||
|
||||
@@ -107,8 +107,8 @@
|
||||
<span class="truncate">filter videos</span>
|
||||
<div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div>
|
||||
</div>
|
||||
|
||||
<div x-show="isVideosDropdownOpen"
|
||||
|
||||
<div x-show="isVideosDropdownOpen"
|
||||
class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10">
|
||||
<div>
|
||||
<template x-for="option in videosKeys" :key="option">
|
||||
|
||||
Reference in New Issue
Block a user