Add local_files_only, encode_videos, fix bugs to pass tests (WIP)
This commit is contained in:
@@ -30,6 +30,12 @@ from torchvision import transforms
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
|
||||
# Location of the dataset info file; load_info() joins it onto the dataset's local directory.
INFO_PATH = "meta/info.json"
|
||||
EPISODES_PATH = "meta/episodes.jsonl"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
# Path template for one episode's video file. Format fields: `episode_chunk` (zero-padded to 3 digits),
# `video_key`, and `episode_index` (zero-padded to 6 digits).
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = (
|
||||
"data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
|
||||
@@ -104,6 +110,32 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
return items_dict
|
||||
|
||||
|
||||
def _get_major_minor(version: str) -> tuple[int]:
|
||||
split = version.strip("v").split(".")
|
||||
return int(split[0]), int(split[1])
|
||||
|
||||
|
||||
def check_version_compatibility(
    repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
    """Check a dataset's format version against the current codebase version.

    Args:
        repo_id: Identifier of the dataset; only used in the emitted messages.
        version_to_check: Version string of the dataset (e.g. "v1.6").
        current_version: Version string of the codebase (e.g. "v2.0").
        enforce_breaking_major: When true, a dataset with an older major version is
            rejected with an error; otherwise it only triggers the warning below.

    Raises:
        ValueError: If the dataset's major version is older than the current major
            version and ``enforce_breaking_major`` is true.

    Warns:
        UserWarning: If the dataset's (major, minor) version is older than the current one.
    """
    current_major, current_minor = _get_major_minor(current_version)
    major_to_check, minor_to_check = _get_major_minor(version_to_check)

    if major_to_check < current_major and enforce_breaking_major:
        raise ValueError(
            f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
            format with v2.0 that is not backward compatible. Please use our conversion script
            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
        )

    # Compare versions as integer tuples rather than floats: float parsing ranks
    # "v2.10" below "v2.9" and fails outright on versions with a patch part ("v2.0.1").
    if (major_to_check, minor_to_check) < (current_major, current_minor):
        warnings.warn(
            f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
            codebase. The current codebase version is {current_version}. You should be fine since
            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
            stacklevel=1,
        )
|
||||
|
||||
|
||||
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
|
||||
num_version = float(version.strip("v"))
|
||||
if num_version < 2 and enforce_v2:
|
||||
@@ -131,30 +163,28 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
|
||||
return version
|
||||
|
||||
|
||||
def load_metadata(local_dir: Path) -> tuple[dict | list]:
|
||||
"""Loads metadata files from a dataset."""
|
||||
info_path = local_dir / "meta/info.json"
|
||||
episodes_path = local_dir / "meta/episodes.jsonl"
|
||||
stats_path = local_dir / "meta/stats.json"
|
||||
tasks_path = local_dir / "meta/tasks.jsonl"
|
||||
def load_info(local_dir: Path) -> dict:
    """Load the dataset info file located at ``local_dir / INFO_PATH``.

    Args:
        local_dir: Local root directory of the dataset.

    Returns:
        The parsed JSON content of the info file.
    """
    # JSON is UTF-8 by specification; don't rely on the platform's locale encoding.
    with open(local_dir / INFO_PATH, encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
with open(info_path) as f:
|
||||
info = json.load(f)
|
||||
|
||||
with jsonlines.open(episodes_path, "r") as reader:
|
||||
episode_dicts = list(reader)
|
||||
|
||||
with open(stats_path) as f:
|
||||
def load_stats(local_dir: Path) -> dict:
    """Load the dataset statistics file located at ``local_dir / STATS_PATH``.

    The parsed JSON is flattened, every leaf value is converted to a ``torch.Tensor``,
    and the result is unflattened again before being returned.

    Args:
        local_dir: Local root directory of the dataset.

    Returns:
        The statistics, with the same nesting as the JSON file and tensors as leaves.
    """
    # JSON is UTF-8 by specification; don't rely on the platform's locale encoding.
    with open(local_dir / STATS_PATH, encoding="utf-8") as f:
        stats = json.load(f)
    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
    return unflatten_dict(stats)
|
||||
|
||||
with jsonlines.open(tasks_path, "r") as reader:
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
    """Read the tasks file at ``local_dir / TASKS_PATH`` into a dict.

    Each record's ``task_index`` becomes a key mapped to that record's ``task``;
    entries are inserted in ascending ``task_index`` order.
    """
    with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
        records = list(reader)

    ordered_records = sorted(records, key=lambda record: record["task_index"])
    return {record["task_index"]: record["task"] for record in ordered_records}
|
||||
|
||||
return info, episode_dicts, stats, tasks
|
||||
|
||||
def load_episode_dicts(local_dir: Path) -> list[dict]:
    """Load the episodes file located at ``local_dir / EPISODES_PATH``.

    Args:
        local_dir: Local root directory of the dataset.

    Returns:
        A list with one entry per line of the JSON Lines file, in file order.
    """
    with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
        return list(reader)
|
||||
|
||||
|
||||
def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
|
||||
@@ -229,7 +259,7 @@ def check_timestamps_sync(
|
||||
# Track original indices before masking
|
||||
original_indices = torch.arange(len(diffs))
|
||||
filtered_indices = original_indices[mask]
|
||||
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze()
|
||||
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
|
||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user