Write episodes as jsonlines

This commit is contained in:
Simon Alibert
2024-10-17 10:17:27 +02:00
parent c146ba936f
commit 50a75ad3fe
3 changed files with 33 additions and 8 deletions

View File

@@ -93,6 +93,7 @@ import warnings
from pathlib import Path
import datasets
import jsonlines
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
@@ -160,6 +161,11 @@ def write_json(data: dict, fpath: Path) -> None:
json.dump(data, f, indent=4)
def write_jsonlines(data: list, fpath: Path) -> None:
    """Write each item of ``data`` to ``fpath`` in JSON Lines format.

    Args:
        data: Sequence of JSON-serializable items. The caller in this file
            passes a list of per-episode dicts. Note that the items are
            obtained by iterating ``data``, so passing a plain dict would
            write its keys rather than its contents.
        fpath: Destination file, opened in write mode ("w").
    """
    with jsonlines.open(fpath, "w") as writer:
        # write_all iterates `data` and emits one record per item.
        writer.write_all(data)
def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
safetensor_path = input_dir / "stats.safetensors"
stats = load_file(safetensor_path)
@@ -617,7 +623,7 @@ def convert_dataset(
{"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices
]
write_json(episodes, v20_dir / "meta" / "episodes.json")
write_jsonlines(episodes, v20_dir / "meta" / "episodes.jsonl")
# Assemble metadata v2.0
metadata_v2_0 = {
@@ -648,6 +654,9 @@ def convert_dataset(
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="data",
@@ -681,6 +690,7 @@ def convert_dataset(
# - [X] Handle multitask datasets
# - [X] Handle hf hub repo limits (add chunks logic)
# - [X] Add test-branch
# - [X] Use jsonlines for episodes
# - [X] Add sanity checks (encoding, shapes)