Improves Type Annotations (#252)

This commit is contained in:
Wael Karkoub
2024-06-10 19:09:48 +01:00
committed by GitHub
parent a06598678c
commit 54c9776bde
7 changed files with 54 additions and 23 deletions

View File

@@ -66,6 +66,7 @@ import argparse
import json
import shutil
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import HfApi
@@ -77,7 +78,7 @@ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_r
from lerobot.common.datasets.utils import flatten_dict
def get_from_raw_to_lerobot_format_fn(raw_format):
def get_from_raw_to_lerobot_format_fn(raw_format: str):
if raw_format == "pusht_zarr":
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
elif raw_format == "umi_zarr":
@@ -96,7 +97,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
return from_raw_to_lerobot_format
def save_meta_data(info, stats, episode_data_index, meta_data_dir):
def save_meta_data(
info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
):
meta_data_dir.mkdir(parents=True, exist_ok=True)
# save info
@@ -114,7 +117,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
save_file(episode_data_index, ep_data_idx_path)
def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
"""Expect all meta data files to be all stored in a single "meta_data" directory.
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
"""
@@ -128,7 +131,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
)
def push_videos_to_hub(repo_id, videos_dir, revision):
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
"""Expect mp4 files to be all stored in a single "videos" directory.
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
"""
@@ -209,6 +212,7 @@ def push_dataset_to_hub(
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if not dry_run:
# TODO(rcadene): token needs to be a str | None
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)