forked from tangger/lerobot
Tests cleaning & simplification (#81)
This commit is contained in:
@@ -1,6 +1,37 @@
|
||||
import os
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.utils.import_utils import is_package_available
|
||||
|
||||
# Pass this as the first argument to init_hydra_config.
|
||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||
|
||||
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def require_env(func):
|
||||
"""
|
||||
Decorator that skips the test if the required environment package is not installed.
|
||||
As it need 'env_name' in args, it also checks whether it is provided as an argument.
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Determine if 'env_name' is provided and extract its value
|
||||
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
|
||||
if "env_name" in arg_names:
|
||||
# Get the index of 'env_name' and retrieve the value from args
|
||||
index = arg_names.index("env_name")
|
||||
env_name = args[index] if len(args) > index else kwargs.get("env_name")
|
||||
else:
|
||||
raise ValueError("Function does not have 'env_name' as an argument.")
|
||||
|
||||
# Perform the package check
|
||||
package_name = f"gym_{env_name}"
|
||||
if not is_package_available(package_name):
|
||||
pytest.skip(f"gym-{env_name} not installed")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user