[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
committed by AdilZouitine
parent 76df8a31b3
commit 38f5fa4523
79 changed files with 2782 additions and 788 deletions

View File

@@ -45,7 +45,9 @@ def concatenate_episodes(ep_dicts):
return data_dict
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
def save_images_concurrently(
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
@@ -55,7 +57,10 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
[
executor.submit(save_image, imgs_array[i], i, out_dir)
for i in range(num_images)
]
def get_default_encoding() -> dict:
@@ -64,7 +69,8 @@ def get_default_encoding() -> dict:
return {
k: v.default
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
if v.default is not inspect.Parameter.empty
and k in ["vcodec", "pix_fmt", "g", "crf"]
}
@@ -77,7 +83,9 @@ def check_repo_id(repo_id: str) -> None:
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.