Enable tests for TD-MPC (#160)

This commit is contained in:
Alexander Soare
2024-05-09 13:42:12 +01:00
committed by GitHub
parent 7bb5b15f4c
commit e89521dfa0
7 changed files with 5 additions and 6 deletions

View File

@@ -22,9 +22,8 @@ test-end-to-end:
${MAKE} test-act-ete-eval ${MAKE} test-act-ete-eval
${MAKE} test-diffusion-ete-train ${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval ${MAKE} test-diffusion-ete-eval
# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc ${MAKE} test-tdmpc-ete-train
# ${MAKE} test-tdmpc-ete-train ${MAKE} test-tdmpc-ete-eval
# ${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval ${MAKE} test-default-ete-eval
test-act-ete-train: test-act-ete-train:
@@ -80,7 +79,7 @@ test-tdmpc-ete-train:
policy=tdmpc \ policy=tdmpc \
env=xarm \ env=xarm \
env.task=XarmLift-v0 \ env.task=XarmLift-v0 \
dataset_repo_id=lerobot/xarm_lift_medium_replay \ dataset_repo_id=lerobot/xarm_lift_medium \
wandb.enable=False \ wandb.enable=False \
training.offline_steps=2 \ training.offline_steps=2 \
training.online_steps=2 \ training.online_steps=2 \

View File

@@ -1,7 +1,7 @@
# @package _global_ # @package _global_
seed: 1 seed: 1
dataset_repo_id: lerobot/xarm_lift_medium_replay dataset_repo_id: lerobot/xarm_lift_medium
training: training:
offline_steps: 25000 offline_steps: 25000

View File

@@ -236,7 +236,7 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name, policy_name, extra_overrides", "env_name, policy_name, extra_overrides",
[ [
# ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), ("xarm", "tdmpc", []),
( (
"pusht", "pusht",
"diffusion", "diffusion",