Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
@@ -43,6 +43,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
     repo_dir.mkdir(parents=True, exist_ok=True)
     dataset = LeRobotDataset(
         repo_id=repo_id,
+        episodes=[0],
     )

     # save 2 first frames of first episode
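For readers unfamiliar with the argument this hunk adds: a minimal sketch (not part of the commit) of how `episodes=[0]` limits the dataset, using the script's default repo id from the hunk header.

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Restrict the artifact script to the first episode so only that episode is
# downloaded and indexed.
dataset = LeRobotDataset(repo_id="lerobot/pusht", episodes=[0])

# Frames are indexed globally across the loaded episodes; with a single episode
# these are its first two frames (the ones the script saves).
frame_0 = dataset[0]
frame_1 = dataset[1]
```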
@@ -19,29 +19,21 @@ import torch
 from safetensors.torch import save_file

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.transforms import get_image_transforms
-from lerobot.common.utils.utils import init_hydra_config, seeded_context
-from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
-from tests.utils import DEFAULT_CONFIG_PATH
+from lerobot.common.datasets.transforms import (
+    ImageTransformConfig,
+    ImageTransforms,
+    ImageTransformsConfig,
+    make_transform_from_config,
+)
+from lerobot.common.utils.utils import seeded_context
+
+ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
+DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"


 def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
-    cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
-    cfg_tf = cfg.training.image_transforms
-    default_tf = get_image_transforms(
-        brightness_weight=cfg_tf.brightness.weight,
-        brightness_min_max=cfg_tf.brightness.min_max,
-        contrast_weight=cfg_tf.contrast.weight,
-        contrast_min_max=cfg_tf.contrast.min_max,
-        saturation_weight=cfg_tf.saturation.weight,
-        saturation_min_max=cfg_tf.saturation.min_max,
-        hue_weight=cfg_tf.hue.weight,
-        hue_min_max=cfg_tf.hue.min_max,
-        sharpness_weight=cfg_tf.sharpness.weight,
-        sharpness_min_max=cfg_tf.sharpness.min_max,
-        max_num_transforms=cfg_tf.max_num_transforms,
-        random_order=cfg_tf.random_order,
-    )
+    cfg = ImageTransformsConfig(enable=True)
+    default_tf = ImageTransforms(cfg)

     with seeded_context(1337):
         img_tf = default_tf(original_frame)
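As a reading aid, a sketch of the dataclass-based API that replaces the Hydra `cfg.training.image_transforms` block above; the dummy frame and its shape are assumptions, everything else mirrors the new lines of the hunk.

```python
import torch

from lerobot.common.datasets.transforms import ImageTransforms, ImageTransformsConfig
from lerobot.common.utils.utils import seeded_context

# The dataclass config replaces the old cfg.training.image_transforms Hydra node.
# enable=True is passed explicitly in the hunk, which suggests transforms are opt-in.
cfg = ImageTransformsConfig(enable=True)
default_tf = ImageTransforms(cfg)

frame = torch.rand(3, 96, 96)  # dummy CHW image standing in for a dataset frame
with seeded_context(1337):
    transformed = default_tf(frame)
```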
@@ -51,29 +43,26 @@ def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):

 def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
     transforms = {
-        "brightness": [(0.5, 0.5), (2.0, 2.0)],
-        "contrast": [(0.5, 0.5), (2.0, 2.0)],
-        "saturation": [(0.5, 0.5), (2.0, 2.0)],
-        "hue": [(-0.25, -0.25), (0.25, 0.25)],
-        "sharpness": [(0.5, 0.5), (2.0, 2.0)],
+        ("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
+        ("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
+        ("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
+        ("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
+        ("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
     }

     frames = {"original_frame": original_frame}
-    for transform, values in transforms.items():
-        for min_max in values:
-            kwargs = {
-                f"{transform}_weight": 1.0,
-                f"{transform}_min_max": min_max,
-            }
-            tf = get_image_transforms(**kwargs)
-            key = f"{transform}_{min_max[0]}_{min_max[1]}"
+    for tf_type, tf_name, min_max_values in transforms.items():
+        for min_max in min_max_values:
+            tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
+            tf = make_transform_from_config(tf_cfg)
+            key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
             frames[key] = tf(original_frame)

     save_file(frames, output_dir / "single_transforms.safetensors")


 def main():
-    dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
+    dataset = LeRobotDataset(DATASET_REPO_ID, episodes=[0], image_transforms=None)
     output_dir = Path(ARTIFACT_DIR)
     output_dir.mkdir(parents=True, exist_ok=True)
     original_frame = dataset[0][dataset.meta.camera_keys[0]]
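Likewise, a hedged sketch of how a single `(type, name, min/max ranges)` entry from the hunk is turned into a transform; the brightness values are taken from the entries above, the rest is only illustrative.

```python
from lerobot.common.datasets.transforms import ImageTransformConfig, make_transform_from_config

# One entry of the table above: a ColorJitter that only jitters brightness, with the
# min/max range collapsed to a single value so the output is reproducible.
tf_cfg = ImageTransformConfig(type="ColorJitter", kwargs={"brightness": (0.5, 0.5)})
tf = make_transform_from_config(tf_cfg)
# transformed = tf(frame)  # apply to a CHW tensor, as save_single_transforms does
```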
@@ -20,32 +20,35 @@ import torch
 from safetensors.torch import save_file

 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils.utils import init_hydra_config, set_global_seed
-from lerobot.scripts.train import make_optimizer_and_scheduler
-from tests.utils import DEFAULT_CONFIG_PATH
+from lerobot.common.optim.factory import make_optimizer_and_scheduler
+from lerobot.common.policies.factory import make_policy, make_policy_config
+from lerobot.common.utils.utils import set_global_seed
+from lerobot.configs.default import DatasetConfig
+from lerobot.configs.train import TrainPipelineConfig


-def get_policy_stats(env_name, policy_name, extra_overrides):
-    cfg = init_hydra_config(
-        DEFAULT_CONFIG_PATH,
-        overrides=[
-            f"env={env_name}",
-            f"policy={policy_name}",
-            "device=cpu",
-        ]
-        + extra_overrides,
-    )
+def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
+    # TODO(rcadene, aliberts): env_name?
     set_global_seed(1337)
-    dataset = make_dataset(cfg)
-    policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
-    policy.train()
-    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)

+    train_cfg = TrainPipelineConfig(
+        # TODO(rcadene, aliberts): remove dataset download
+        dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
+        policy=make_policy_config(policy_name, **policy_kwargs),
+        device="cpu",
+        **train_kwargs,
+    )
+    train_cfg.validate()  # Needed for auto-setting some parameters
+
+    dataset = make_dataset(train_cfg)
+    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
+    policy.train()
+
+    optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=0,
-        batch_size=cfg.training.batch_size,
+        batch_size=train_cfg.batch_size,
         shuffle=False,
     )
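To see the new pieces together, a sketch of the config-driven setup this hunk switches to, using the pusht/diffusion pair listed at the bottom of the file; names and arguments follow the diff, anything beyond that is an assumption.

```python
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy, make_policy_config
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig

# Dataset/policy pair and kwargs taken from the env_policies list further down.
train_cfg = TrainPipelineConfig(
    dataset=DatasetConfig(repo_id="lerobot/pusht", episodes=[0]),
    policy=make_policy_config("diffusion", n_action_steps=8, num_inference_steps=10),
    device="cpu",
)
train_cfg.validate()  # auto-fills derived fields before anything is instantiated

dataset = make_dataset(train_cfg)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
optimizer, lr_scheduler = make_optimizer_and_scheduler(train_cfg, policy)
```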
@@ -72,24 +75,28 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     optimizer.zero_grad()
     policy.reset()

-    # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
-    dataset.delta_timestamps = None
+    # HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension
+    # We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors
+    # indicating padding (those ending with "_is_pad")
+    dataset.delta_indices = None
     batch = next(iter(dataloader))
     obs = {}
     for k in batch:
+        if k.endswith("_is_pad"):
+            continue
         if k.startswith("observation"):
             obs[k] = batch[k]

-    if "n_action_steps" in cfg.policy:
-        actions_queue = cfg.policy.n_action_steps
+    if hasattr(train_cfg.policy, "n_action_steps"):
+        actions_queue = train_cfg.policy.n_action_steps
     else:
-        actions_queue = cfg.policy.n_action_repeats
+        actions_queue = train_cfg.policy.n_action_repeats

     actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
     return output_dict, grad_stats, param_stats, actions


-def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
+def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra):
     env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"

     if env_policy_dir.exists():
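A self-contained restatement of the new padding-aware filtering, with a hypothetical batch; the key names are illustrative, not taken from a real dataset.

```python
import torch

# Dummy batch shaped like what the dataloader above yields (key names illustrative).
batch = {
    "observation.state": torch.zeros(2, 4),
    "observation.state_is_pad": torch.zeros(2, dtype=torch.bool),
    "action": torch.zeros(2, 4),
}

# Equivalent to the filtering loop in this hunk: keep observation.* tensors and drop
# both the padding masks ("*_is_pad") and everything else (actions, indices, ...).
obs = {k: v for k, v in batch.items() if k.startswith("observation") and not k.endswith("_is_pad")}
assert set(obs) == {"observation.state"}
```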
@@ -99,7 +106,7 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
         shutil.rmtree(env_policy_dir)

     env_policy_dir.mkdir(parents=True, exist_ok=True)
-    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
+    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs)
     save_file(output_dict, env_policy_dir / "output_dict.safetensors")
     save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
     save_file(param_stats, env_policy_dir / "param_stats.safetensors")
@@ -108,26 +115,31 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):

 if __name__ == "__main__":
     env_policies = [
-        # ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
-        # ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
-        # (
-        #     "pusht",
-        #     "diffusion",
-        #     [
-        #         "policy.n_action_steps=8",
-        #         "policy.num_inference_steps=10",
-        #         "policy.down_dims=[128, 256, 512]",
-        #     ],
-        #     "",
-        # ),
-        # ("aloha", "act", ["policy.n_action_steps=10"], ""),
-        # ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
-        # ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
-        # ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
+        ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"),
+        ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"),
+        (
+            "lerobot/pusht",
+            "pusht",
+            "diffusion",
+            {
+                "n_action_steps": 8,
+                "num_inference_steps": 10,
+                "down_dims": [128, 256, 512],
+            },
+            "",
+        ),
+        ("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""),
+        (
+            "lerobot/aloha_sim_insertion_human",
+            "aloha",
+            "act",
+            {"n_action_steps": 1000, "chunk_size": 1000},
+            "_1000_steps",
+        ),
     ]
     if len(env_policies) == 0:
         raise RuntimeError("No policies were provided!")
-    for env, policy, extra_overrides, file_name_extra in env_policies:
+    for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies:
         save_policy_to_safetensors(
-            "tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
+            "tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra
         )
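Finally, a note on the tuple layout above: the per-policy dict replaces the old `policy.*=...` Hydra override strings and is splatted into `make_policy_config` inside `get_policy_stats`. A minimal sketch for the `_1000_steps` ACT entry, with values taken from the list above:

```python
from lerobot.common.policies.factory import make_policy_config

# What used to be the Hydra overrides "policy.n_action_steps=1000" and
# "policy.chunk_size=1000" is now a plain kwargs dict passed to make_policy_config.
act_cfg = make_policy_config("act", n_action_steps=1000, chunk_size=1000)
```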