Add DEVICE constant from LEROBOT_TESTS_DEVICE

This commit is contained in:
Cadene
2024-03-12 14:14:39 +00:00
parent 29c73844b1
commit 5881eec376
5 changed files with 12 additions and 5 deletions

View File

@@ -3,7 +3,7 @@ import torch
from lerobot.common.datasets.factory import make_offline_buffer
from .utils import init_config
from .utils import DEVICE, init_config
@pytest.mark.parametrize(
@@ -20,7 +20,7 @@ from .utils import init_config
],
)
def test_factory(env_name, dataset_id):
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
offline_buffer = make_offline_buffer(cfg)
for key in offline_buffer.image_keys:
img = offline_buffer[0].get(key)