WIP: add multiprocess

This commit is contained in:
Remi Cadene
2024-09-28 15:00:38 +02:00
parent 9b76ee9eb0
commit 77ba43d25b
2 changed files with 66 additions and 25 deletions

View File

@@ -92,12 +92,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
if robot_type == "aloha":
pytest.skip("TODO(rcadene): enable test once aloha_real and act_aloha_real are merged")
env_name = "koch_real"
policy_name = "act_koch_real"
root = Path(tmpdir)
repo_id = "lerobot/debug"
@@ -117,13 +111,28 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
# TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha":
env_name = "aloha_real"
policy_name = "act_aloha_real"
elif robot_type in ["koch", "koch_bimanual"]:
env_name = "koch_real"
policy_name = "act_koch_real"
else:
raise NotImplementedError(robot_type)
overrides = [
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
]
if robot_type == "koch_bimanual":
overrides += ["env.state_dim=12", "env.action_dim=12"]
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
],
overrides=overrides,
)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)