Tidy up yaml configs (#121)
This commit is contained in:
@@ -33,7 +33,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"dataset.repo_id={repo_id}",
|
||||
f"dataset_repo_id={repo_id}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
],
|
||||
|
||||
@@ -39,7 +39,7 @@ def test_examples_3_and_2():
|
||||
("training_steps = 5000", "training_steps = 1"),
|
||||
("num_workers=4", "num_workers=0"),
|
||||
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
||||
("batch_size=cfg.batch_size", "batch_size=1"),
|
||||
("batch_size=64", "batch_size=1"),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -58,8 +58,8 @@ def test_examples_3_and_2():
|
||||
file_contents = _find_and_replace(
|
||||
file_contents,
|
||||
[
|
||||
('"eval_episodes=10"', '"eval_episodes=1"'),
|
||||
('"rollout_batch_size=10"', '"rollout_batch_size=1"'),
|
||||
('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
|
||||
('"eval.batch_size=10"', '"eval.batch_size=1"'),
|
||||
('"device=cuda"', '"device=cpu"'),
|
||||
(
|
||||
'# folder = Path("outputs/train/example_pusht_diffusion")',
|
||||
|
||||
@@ -21,21 +21,21 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
# ("xarm", "tdmpc", ["policy.mpc=true"]),
|
||||
# ("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
("pusht", "diffusion", []),
|
||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"],
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_scripted"],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"],
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_human"],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ def test_visualize_dataset(tmpdir, repo_id):
|
||||
overrides=[
|
||||
"policy=act",
|
||||
"env=aloha",
|
||||
f"dataset.repo_id={repo_id}",
|
||||
f"dataset_repo_id={repo_id}",
|
||||
],
|
||||
)
|
||||
video_paths = visualize_dataset(cfg, out_dir=tmpdir)
|
||||
|
||||
Reference in New Issue
Block a user