Improve push_dataset_to_hub API + Add unit tests (#231)

Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Thomas Wolf
2024-06-13 15:18:02 +02:00
committed by GitHub
parent c38f535c9f
commit 125bd93e29
11 changed files with 750 additions and 419 deletions

View File

@@ -76,6 +76,7 @@ def require_env(func):
"""
Decorator that skips the test if the required environment package is not installed.
As it need 'env_name' in args, it also checks whether it is provided as an argument.
If 'env_name' is None, this check is skipped.
"""
@wraps(func)
@@ -91,7 +92,7 @@ def require_env(func):
# Perform the package check
package_name = f"gym_{env_name}"
if not is_package_available(package_name):
if env_name is not None and not is_package_available(package_name):
pytest.skip(f"gym-{env_name} not installed")
return func(*args, **kwargs)
@@ -99,6 +100,38 @@ def require_env(func):
return wrapper
def require_package_arg(func):
"""
Decorator that skips the test if the required package is not installed.
This is similar to `require_env` but more general in that it can check any package (not just environments).
As it need 'required_packages' in args, it also checks whether it is provided as an argument.
If 'required_packages' is None, this check is skipped.
"""
@wraps(func)
def wrapper(*args, **kwargs):
# Determine if 'required_packages' is provided and extract its value
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
if "required_packages" in arg_names:
# Get the index of 'required_packages' and retrieve the value from args
index = arg_names.index("required_packages")
required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
else:
raise ValueError("Function does not have 'required_packages' as an argument.")
if required_packages is None:
return func(*args, **kwargs)
# Perform the package check
for package in required_packages:
if not is_package_available(package):
pytest.skip(f"{package} not installed")
return func(*args, **kwargs)
return wrapper
def require_package(package_name):
"""
Decorator that skips the test if the specified package is not installed.