Convert datasets to av1 encoding (#302)

This commit is contained in:
Simon Alibert
2024-07-22 20:08:59 +02:00
committed by GitHub
parent 461d5472d3
commit 0b21210d72
571 changed files with 988 additions and 1311 deletions

View File

@@ -101,7 +101,7 @@ from termcolor import colored
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
from lerobot.common.datasets.utils import calculate_episode_data_index
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
@@ -479,6 +479,8 @@ def record_dataset(
"fps": fps,
"video": video,
}
if video:
info["encoding"] = get_default_encoding()
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,

View File

@@ -55,6 +55,7 @@ from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
from lerobot.common.datasets.utils import flatten_dict
@@ -140,14 +141,12 @@ def push_dataset_to_hub(
num_workers: int = 8,
episodes: list[int] | None = None,
force_override: bool = False,
resume: bool = False,
cache_dir: Path = Path("/tmp"),
tests_data_dir: Path | None = None,
encoding: dict | None = None,
):
# Check repo_id is well formated
if len(repo_id.split("/")) != 2:
raise ValueError(
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
)
check_repo_id(repo_id)
user_id, dataset_id = repo_id.split("/")
# Robustify when `raw_dir` is str instead of Path
@@ -173,7 +172,7 @@ def push_dataset_to_hub(
if local_dir.exists():
if force_override:
shutil.rmtree(local_dir)
else:
elif not resume:
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
meta_data_dir = local_dir / "meta_data"
@@ -191,7 +190,7 @@ def push_dataset_to_hub(
# convert dataset from original raw format to LeRobot format
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
raw_dir, videos_dir, fps, video, episodes
raw_dir, videos_dir, fps, video, episodes, encoding
)
lerobot_dataset = LeRobotDataset.from_preloaded(
@@ -315,6 +314,12 @@ def main():
default=0,
help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
)
parser.add_argument(
"--resume",
type=int,
default=0,
help="When set to 1, resumes a previous run.",
)
parser.add_argument(
"--tests-data-dir",
type=Path,