Fix tests

This commit is contained in:
Simon Alibert
2024-11-05 19:09:12 +01:00
parent aed9f4036a
commit f3630ad910
13 changed files with 437 additions and 496 deletions

View File

@@ -22,8 +22,6 @@ import numpy as np
import PIL.Image
import torch
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
@@ -87,7 +85,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
t.join()
class ImageWriter:
class AsyncImageWriter:
"""
This class abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
@@ -102,11 +100,7 @@ class ImageWriter:
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
self.write_dir = write_dir
self.write_dir.mkdir(parents=True, exist_ok=True)
self.image_path = DEFAULT_IMAGE_PATH
def __init__(self, num_processes: int = 0, num_threads: int = 1):
self.num_processes = num_processes
self.num_threads = num_threads
self.queue = None
@@ -134,17 +128,6 @@ class ImageWriter:
p.start()
self.processes.append(p)
def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = self.image_path.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self.write_dir / fpath
def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
return self.get_image_file_path(
episode_index=episode_index, image_key=image_key, frame_index=0
).parent
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time

View File

@@ -22,15 +22,18 @@ from pathlib import Path
from typing import Callable
import datasets
import numpy as np
import PIL.Image
import torch
import torch.utils
from datasets import load_dataset
from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
@@ -44,6 +47,7 @@ from lerobot.common.datasets.utils import (
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hf_features_from_features,
get_hub_safe_version,
hf_transform_to_torch,
load_episodes,
@@ -140,14 +144,9 @@ class LeRobotDatasetMetadata:
@property
def features(self) -> dict[str, dict]:
""""""
"""All features contained in the dataset."""
return self.info["features"]
@property
def keys(self) -> list[str]:
"""Keys to access non-image data (state, actions etc.)."""
return self.info["keys"]
@property
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
@@ -268,7 +267,7 @@ class LeRobotDatasetMetadata:
obj.root = root if root is not None else LEROBOT_HOME / repo_id
if robot is not None:
features = get_features_from_robot(robot)
features = get_features_from_robot(robot, use_videos)
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
@@ -522,35 +521,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
@property
def features(self) -> list[str]:
return list(self._features) + self.meta.video_keys
def features(self) -> dict[str, dict]:
return self.meta.features
@property
def _features(self) -> datasets.Features:
def hf_features(self) -> datasets.Features:
"""Features of the hf_dataset."""
if self.hf_dataset is not None:
return self.hf_dataset.features
elif self.episode_buffer is None:
raise NotImplementedError(
"Dataset features must be infered from an existing hf_dataset or episode_buffer."
)
features = {}
for key in self.episode_buffer:
if key in ["episode_index", "frame_index", "index", "task_index"]:
features[key] = datasets.Value(dtype="int64")
elif key in ["next.done", "next.success"]:
features[key] = datasets.Value(dtype="bool")
elif key in ["timestamp", "next.reward"]:
features[key] = datasets.Value(dtype="float32")
elif key in self.meta.image_keys:
features[key] = datasets.Image()
elif key in self.meta.keys:
features[key] = datasets.Sequence(
length=self.meta.shapes[key], feature=datasets.Value(dtype="float32")
)
return datasets.Features(features)
else:
return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep_start = self.episode_data_index["from"][ep_idx]
@@ -650,17 +630,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
# TODO(aliberts): Handle resume
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
return {
"size": 0,
"episode_index": self.meta.total_episodes if episode_index is None else episode_index,
"task_index": None,
"frame_index": [],
"timestamp": [],
**{key: [] for key in self.meta.features},
**{key: [] for key in self.meta.image_keys},
**{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
}
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self.root / fpath
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
write_image(image, fpath)
else:
self.image_writer.save_image(image=image, fpath=fpath)
def add_frame(self, frame: dict) -> None:
"""
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
@@ -668,35 +657,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
then needs to be called.
"""
frame_index = self.episode_buffer["size"]
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(frame_index / self.fps)
self.episode_buffer["next.done"].append(False)
# Save all observed modalities except images
for key in self.meta.keys:
self.episode_buffer[key].append(frame[key])
for key, ft in self.features.items():
if key == "frame_index":
self.episode_buffer[key].append(frame_index)
elif key == "timestamp":
self.episode_buffer[key].append(frame_index / self.fps)
elif key in frame and ft["dtype"] not in ["image", "video"]:
self.episode_buffer[key].append(frame[key])
elif key in frame and ft["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
self._save_image(frame[key], img_path)
if ft["dtype"] == "image":
self.episode_buffer[key].append(str(img_path))
self.episode_buffer["size"] += 1
if self.image_writer is None:
return
# Save images
for cam_key in self.meta.camera_keys:
img_path = self.image_writer.get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
self.image_writer.save_image(
image=frame[cam_key],
fpath=img_path,
)
if cam_key in self.meta.image_keys:
self.episode_buffer[cam_key].append(str(img_path))
def add_episode(self, task: str, encode_videos: bool = False) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
@@ -714,23 +693,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
raise NotImplementedError()
task_index = self.meta.get_task_index(task)
self.episode_buffer["next.done"][-1] = True
for key in self.episode_buffer:
if key in self.meta.image_keys:
continue
elif key in self.meta.keys:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
if not set(self.episode_buffer.keys()) == set(self.features):
raise ValueError()
for key, ft in self.features.items():
if key == "index":
self.episode_buffer[key] = np.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
elif key == "episode_index":
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
self.episode_buffer[key] = np.full((episode_length,), episode_index)
elif key == "task_index":
self.episode_buffer[key] = torch.full((episode_length,), task_index)
else:
self.episode_buffer[key] = np.full((episode_length,), task_index)
elif ft["dtype"] in ["image", "video"]:
continue
elif ft["shape"][0] == 1:
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
elif ft["shape"][0] > 1:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
else:
raise ValueError()
self.episode_buffer["index"] = torch.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
self.meta.add_episode(episode_index, episode_length, task, task_index)
self._wait_image_writer()
@@ -744,7 +728,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.consolidated = False
def _save_episode_table(self, episode_index: int) -> None:
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.hf_features, split="train")
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
write_parquet(ep_dataset, ep_data_path)
@@ -753,7 +737,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index = self.episode_buffer["episode_index"]
if self.image_writer is not None:
for cam_key in self.meta.camera_keys:
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
img_dir = self._get_image_file_path(
episode_index=episode_index, image_key=cam_key, frame_index=0
).parent
if img_dir.is_dir():
shutil.rmtree(img_dir)
@@ -761,13 +747,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer = self._create_episode_buffer()
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
if isinstance(self.image_writer, ImageWriter):
if isinstance(self.image_writer, AsyncImageWriter):
logging.warning(
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
"You are starting a new AsyncImageWriter that is replacing an already exising one in the dataset."
)
self.image_writer = ImageWriter(
write_dir=self.root / "images",
self.image_writer = AsyncImageWriter(
num_processes=num_processes,
num_threads=num_threads,
)
@@ -787,19 +772,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.image_writer.wait_until_done()
def encode_videos(self) -> None:
# Use ffmpeg to convert frames stored as png into mp4 videos
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
for episode_index in range(self.meta.total_episodes):
for key in self.meta.video_keys:
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
# to call self.image_writer here
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
if video_path.is_file():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
img_dir = self._get_image_file_path(
episode_index=episode_index, image_key=key, frame_index=0
).parent
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
@@ -810,8 +797,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.encode_videos()
self.meta.write_video_info()
if not keep_image_files and self.image_writer is not None:
shutil.rmtree(self.image_writer.write_dir)
if not keep_image_files:
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
@@ -989,7 +978,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
features.update(
{k: v for k, v in dataset.hf_features.items() if k not in self.disabled_data_keys}
)
return features
@property

View File

@@ -22,6 +22,7 @@ from typing import Any
import datasets
import jsonlines
import pyarrow.compute as pc
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, HfApi
@@ -39,6 +40,7 @@ TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DATASET_CARD_TEMPLATE = """
---
@@ -222,6 +224,24 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
return version
def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"])
else:
assert len(ft["shape"]) == 1
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
return datasets.Features(hf_features)
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
camera_ft = {}
if robot.cameras:
@@ -270,6 +290,31 @@ def get_episode_data_index(
}
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> dict[str, torch.Tensor]:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lenghts = list(accumulate(episode_lengths))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
def check_timestamps_sync(
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],