Remove update method from the policy (#99)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Quentin Gallouédec
2024-04-29 12:27:58 +02:00
committed by GitHub
parent 5b4fd8891d
commit 508bd92d03
8 changed files with 84 additions and 122 deletions

View File

@@ -18,8 +18,8 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
("xarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]),
# ("xarm", "tdmpc", ["policy.mpc=true"]),
# ("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
(
@@ -86,7 +86,7 @@ def test_policy(env_name, policy_name, extra_overrides):
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy
policy.update(batch, step=0)
policy.forward(batch, step=0)
# reset the policy and environment
policy.reset()