Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -26,13 +26,13 @@ from pathlib import Path
from textwrap import dedent
from lerobot import available_datasets
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
LOCAL_DIR = Path("data/")
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
ALOHA_MOBILE_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"robot_config": AlohaRobotConfig(),
"license": "mit",
"url": "https://mobile-aloha.github.io/",
"paper": "https://arxiv.org/abs/2401.02117",
@@ -45,7 +45,7 @@ ALOHA_MOBILE_INFO = {
}""").lstrip(),
}
ALOHA_STATIC_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"robot_config": AlohaRobotConfig(),
"license": "mit",
"url": "https://tonyzhaozh.github.io/aloha/",
"paper": "https://arxiv.org/abs/2304.13705",

View File

@@ -141,7 +141,8 @@ from lerobot.common.datasets.video_utils import (
get_image_pixel_channels,
get_video_info,
)
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_config
V16 = "v1.6"
V20 = "v2.0"
@@ -152,19 +153,18 @@ V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
robot_cfg = init_hydra_config(config_path, config_overrides)
if robot_cfg["robot_type"] in ["aloha", "koch"]:
def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
if robot_cfg.type in ["aloha", "koch"]:
state_names = [
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
for arm in robot_cfg["follower_arms"]
for motor in robot_cfg["follower_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
for arm in robot_cfg.follower_arms
for motor in robot_cfg.follower_arms[arm].motors
]
action_names = [
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
for arm in robot_cfg["leader_arms"]
for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
for arm in robot_cfg.leader_arms
for motor in robot_cfg.leader_arms[arm].motors
]
# elif robot_cfg["robot_type"] == "stretch3": TODO
else:
@@ -173,7 +173,7 @@ def parse_robot_config(config_path: Path, config_overrides: list[str] | None = N
)
return {
"robot_type": robot_cfg["robot_type"],
"robot_type": robot_cfg.type,
"names": {
"observation.state": state_names,
"observation.effort": state_names,
@@ -203,7 +203,10 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
torch.testing.assert_close(stats_json[key], stats[key])
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
def get_features_from_hf_dataset(
dataset: Dataset, robot_config: RobotConfig | None = None
) -> dict[str, list]:
robot_config = parse_robot_config(robot_config)
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
@@ -224,11 +227,11 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
image = dataset[0][key] # Assuming first row
channels = get_image_pixel_channels(image)
shape = (image.height, image.width, channels)
names = ["height", "width", "channel"]
names = ["height", "width", "channels"]
elif ft._type == "VideoFrame":
dtype = "video"
shape = None # Add shape later
names = ["height", "width", "channel"]
names = ["height", "width", "channels"]
features[key] = {
"dtype": dtype,
@@ -436,7 +439,7 @@ def convert_dataset(
single_task: str | None = None,
tasks_path: Path | None = None,
tasks_col: Path | None = None,
robot_config: dict | None = None,
robot_config: RobotConfig | None = None,
test_branch: str | None = None,
**card_kwargs,
):
@@ -532,7 +535,7 @@ def convert_dataset(
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
if robot_config is not None:
robot_type = robot_config["robot_type"]
robot_type = robot_config.type
repo_tags = [robot_type]
else:
robot_type = "unknown"
@@ -621,16 +624,10 @@ def main():
help="The path to a .json file containing one language instruction for each episode_index",
)
parser.add_argument(
"--robot-config",
type=Path,
default=None,
help="Path to the robot's config yaml the dataset during conversion.",
)
parser.add_argument(
"--robot-overrides",
"--robot",
type=str,
nargs="*",
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)",
default=None,
help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
)
parser.add_argument(
"--local-dir",
@@ -655,8 +652,10 @@ def main():
if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
del args.robot_config, args.robot_overrides
if args.robot is not None:
robot_config = make_robot_config(args.robot)
del args.robot
convert_dataset(**vars(args), robot_config=robot_config)