forked from tangger/lerobot
Make policies compatible with other/multiple image keys (#149)
This commit is contained in:
@@ -64,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
(
|
||||
"aloha",
|
||||
"diffusion",
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
||||
],
|
||||
)
|
||||
@require_env
|
||||
@@ -87,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
+ extra_overrides,
|
||||
)
|
||||
|
||||
# Additional config override logic.
|
||||
if env_name == "aloha" and policy_name == "diffusion":
|
||||
for keys in [
|
||||
("training", "delta_timestamps"),
|
||||
("policy", "input_shapes"),
|
||||
("policy", "input_normalization_modes"),
|
||||
]:
|
||||
dct = dict(cfg[keys[0]][keys[1]])
|
||||
dct["observation.images.top"] = dct["observation.image"]
|
||||
del dct["observation.image"]
|
||||
cfg[keys[0]][keys[1]] = dct
|
||||
cfg.override_dataset_stats = None
|
||||
|
||||
# Additional config override logic.
|
||||
if env_name == "pusht" and policy_name == "act":
|
||||
for keys in [
|
||||
("policy", "input_shapes"),
|
||||
("policy", "input_normalization_modes"),
|
||||
]:
|
||||
dct = dict(cfg[keys[0]][keys[1]])
|
||||
dct["observation.image"] = dct["observation.images.top"]
|
||||
del dct["observation.images.top"]
|
||||
cfg[keys[0]][keys[1]] = dct
|
||||
cfg.override_dataset_stats = None
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
|
||||
Reference in New Issue
Block a user