Add DEVICE constant from LEROBOT_TESTS_DEVICE

This commit is contained in:
Cadene
2024-03-12 14:14:39 +00:00
parent 29c73844b1
commit 5881eec376
5 changed files with 12 additions and 5 deletions

View File

@@ -1,3 +1,4 @@
import os
import pytest
from tensordict import TensorDict
import torch
@@ -8,7 +9,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.envs.pusht.env import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv
from .utils import init_config
from .utils import DEVICE, init_config
def print_spec_rollout(env):
@@ -89,7 +90,7 @@ def test_pusht(from_pixels, pixels_only):
],
)
def test_factory(env_name):
cfg = init_config(overrides=[f"env={env_name}"])
cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
offline_buffer = make_offline_buffer(cfg)