Follow transformers single file naming conventions (#124)
This commit is contained in:
@@ -4,7 +4,7 @@ import gymnasium as gym
|
||||
import pytest
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
from tests.utils import require_env
|
||||
@@ -30,7 +30,7 @@ def test_available_policies():
|
||||
consistent with those listed in `lerobot/__init__.py`.
|
||||
"""
|
||||
policy_classes = [
|
||||
ActionChunkingTransformerPolicy,
|
||||
ACTPolicy,
|
||||
DiffusionPolicy,
|
||||
TDMPCPolicy,
|
||||
]
|
||||
|
||||
@@ -5,7 +5,7 @@ from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
@@ -115,7 +115,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
new_policy.load_state_dict(policy.state_dict())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy])
|
||||
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ACTPolicy])
|
||||
def test_policy_defaults(policy_cls):
|
||||
kwargs = {}
|
||||
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
|
||||
|
||||
Reference in New Issue
Block a user