Remove update method from the policy (#99)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
committed by
GitHub
parent
5b4fd8891d
commit
508bd92d03
@@ -18,8 +18,8 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
("xarm", "tdmpc", ["policy.mpc=true"]),
|
||||
("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
# ("xarm", "tdmpc", ["policy.mpc=true"]),
|
||||
# ("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
("pusht", "diffusion", []),
|
||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||
(
|
||||
@@ -86,7 +86,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy
|
||||
policy.update(batch, step=0)
|
||||
policy.forward(batch, step=0)
|
||||
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
|
||||
Reference in New Issue
Block a user