Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -43,6 +43,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
repo_dir.mkdir(parents=True, exist_ok=True)
dataset = LeRobotDataset(
repo_id=repo_id,
episodes=[0],
)
# save 2 first frames of first episode

View File

@@ -19,29 +19,21 @@ import torch
from safetensors.torch import save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH
from lerobot.common.datasets.transforms import (
ImageTransformConfig,
ImageTransforms,
ImageTransformsConfig,
make_transform_from_config,
)
from lerobot.common.utils.utils import seeded_context
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
)
cfg = ImageTransformsConfig(enable=True)
default_tf = ImageTransforms(cfg)
with seeded_context(1337):
img_tf = default_tf(original_frame)
@@ -51,29 +43,26 @@ def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
transforms = {
"brightness": [(0.5, 0.5), (2.0, 2.0)],
"contrast": [(0.5, 0.5), (2.0, 2.0)],
"saturation": [(0.5, 0.5), (2.0, 2.0)],
"hue": [(-0.25, -0.25), (0.25, 0.25)],
"sharpness": [(0.5, 0.5), (2.0, 2.0)],
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
}
frames = {"original_frame": original_frame}
for transform, values in transforms.items():
for min_max in values:
kwargs = {
f"{transform}_weight": 1.0,
f"{transform}_min_max": min_max,
}
tf = get_image_transforms(**kwargs)
key = f"{transform}_{min_max[0]}_{min_max[1]}"
for tf_type, tf_name, min_max_values in transforms.items():
for min_max in min_max_values:
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
tf = make_transform_from_config(tf_cfg)
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
frames[key] = tf(original_frame)
save_file(frames, output_dir / "single_transforms.safetensors")
def main():
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
dataset = LeRobotDataset(DATASET_REPO_ID, episodes=[0], image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
original_frame = dataset[0][dataset.meta.camera_keys[0]]

View File

@@ -20,32 +20,35 @@ import torch
from safetensors.torch import save_file
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.utils import DEFAULT_CONFIG_PATH
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy, make_policy_config
from lerobot.common.utils.utils import set_global_seed
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
def get_policy_stats(env_name, policy_name, extra_overrides):
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
f"env={env_name}",
f"policy={policy_name}",
"device=cpu",
]
+ extra_overrides,
)
def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
# TODO(rcadene, aliberts): env_name?
set_global_seed(1337)
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs),
device="cpu",
**train_kwargs,
)
train_cfg.validate() # Needed for auto-setting some parameters
dataset = make_dataset(train_cfg)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=cfg.training.batch_size,
batch_size=train_cfg.batch_size,
shuffle=False,
)
@@ -72,24 +75,28 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
optimizer.zero_grad()
policy.reset()
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
dataset.delta_timestamps = None
# HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension
# We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors
# indicating padding (those ending with "_is_pad")
dataset.delta_indices = None
batch = next(iter(dataloader))
obs = {}
for k in batch:
if k.endswith("_is_pad"):
continue
if k.startswith("observation"):
obs[k] = batch[k]
if "n_action_steps" in cfg.policy:
actions_queue = cfg.policy.n_action_steps
if hasattr(train_cfg.policy, "n_action_steps"):
actions_queue = train_cfg.policy.n_action_steps
else:
actions_queue = cfg.policy.n_action_repeats
actions_queue = train_cfg.policy.n_action_repeats
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
return output_dict, grad_stats, param_stats, actions
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
if env_policy_dir.exists():
@@ -99,7 +106,7 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
shutil.rmtree(env_policy_dir)
env_policy_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs)
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
@@ -108,26 +115,31 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
if __name__ == "__main__":
env_policies = [
# ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
# (
# "pusht",
# "diffusion",
# [
# "policy.n_action_steps=8",
# "policy.num_inference_steps=10",
# "policy.down_dims=[128, 256, 512]",
# ],
# "",
# ),
# ("aloha", "act", ["policy.n_action_steps=10"], ""),
# ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"),
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"),
(
"lerobot/pusht",
"pusht",
"diffusion",
{
"n_action_steps": 8,
"num_inference_steps": 10,
"down_dims": [128, 256, 512],
},
"",
),
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
"act",
{"n_action_steps": 1000, "chunk_size": 1000},
"_1000_steps",
),
]
if len(env_policies) == 0:
raise RuntimeError("No policies were provided!")
for env, policy, extra_overrides, file_name_extra in env_policies:
for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies:
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
"tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra
)