Follow transformers single file naming conventions (#124)

This commit is contained in:
Alexander Soare
2024-05-01 13:09:42 +01:00
committed by GitHub
parent 986583dc5c
commit 01d5490d44
6 changed files with 62 additions and 58 deletions

View File

@@ -4,7 +4,7 @@ import gymnasium as gym
import pytest
import lerobot
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
from tests.utils import require_env
@@ -30,7 +30,7 @@ def test_available_policies():
consistent with those listed in `lerobot/__init__.py`.
"""
policy_classes = [
ActionChunkingTransformerPolicy,
ACTPolicy,
DiffusionPolicy,
TDMPCPolicy,
]

View File

@@ -5,7 +5,7 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.normalize import Normalize, Unnormalize
@@ -115,7 +115,7 @@ def test_policy(env_name, policy_name, extra_overrides):
new_policy.load_state_dict(policy.state_dict())
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy])
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ACTPolicy])
def test_policy_defaults(policy_cls):
kwargs = {}
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.