Add extra info to dataset card, various fixes from Remi's review

This commit is contained in:
Simon Alibert
2024-11-18 17:50:13 +01:00
parent 4d15861872
commit a91b7c6163
5 changed files with 250 additions and 82 deletions

View File

@@ -213,8 +213,11 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
assert isinstance(ft.feature, datasets.Value)
dtype = ft.feature.dtype
shape = (ft.length,)
names = robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
assert len(names) == shape[0]
motor_names = (
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
)
assert len(motor_names) == shape[0]
names = {"motors": motor_names}
elif isinstance(ft, datasets.Image):
dtype = "image"
image = dataset[0][key] # Assuming first row
@@ -433,6 +436,9 @@ def convert_dataset(
tasks_path: Path | None = None,
tasks_col: Path | None = None,
robot_config: dict | None = None,
license: str | None = None,
citation: str | None = None,
arxiv: str | None = None,
test_branch: str | None = None,
):
v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
@@ -559,7 +565,9 @@ def convert_dataset(
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)
card = create_lerobot_dataset_card(
tags=repo_tags, info=metadata_v2_0, license=license, citation=citation, arxiv=arxiv
)
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
@@ -634,6 +642,12 @@ def main():
default=None,
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
)
parser.add_argument(
"--license",
type=str,
default="mit",
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
)
parser.add_argument(
"--test-branch",
type=str,
@@ -652,7 +666,4 @@ def main():
if __name__ == "__main__":
from time import sleep
sleep(1)
main()