Fix tests
This commit is contained in:
@@ -22,8 +22,6 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
@@ -87,7 +85,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
|
||||
t.join()
|
||||
|
||||
|
||||
class ImageWriter:
|
||||
class AsyncImageWriter:
|
||||
"""
|
||||
This class abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
@@ -102,11 +100,7 @@ class ImageWriter:
|
||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
|
||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
|
||||
self.write_dir = write_dir
|
||||
self.write_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.image_path = DEFAULT_IMAGE_PATH
|
||||
|
||||
def __init__(self, num_processes: int = 0, num_threads: int = 1):
|
||||
self.num_processes = num_processes
|
||||
self.num_threads = num_threads
|
||||
self.queue = None
|
||||
@@ -134,17 +128,6 @@ class ImageWriter:
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = self.image_path.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.write_dir / fpath
|
||||
|
||||
def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self.get_image_file_path(
|
||||
episode_index=episode_index, image_key=image_key, frame_index=0
|
||||
).parent
|
||||
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
|
||||
@@ -22,15 +22,18 @@ from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||
from lerobot.common.datasets.image_writer import ImageWriter
|
||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
@@ -44,6 +47,7 @@ from lerobot.common.datasets.utils import (
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_features_from_features,
|
||||
get_hub_safe_version,
|
||||
hf_transform_to_torch,
|
||||
load_episodes,
|
||||
@@ -140,14 +144,9 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
""""""
|
||||
"""All features contained in the dataset."""
|
||||
return self.info["features"]
|
||||
|
||||
@property
|
||||
def keys(self) -> list[str]:
|
||||
"""Keys to access non-image data (state, actions etc.)."""
|
||||
return self.info["keys"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as images."""
|
||||
@@ -268,7 +267,7 @@ class LeRobotDatasetMetadata:
|
||||
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
|
||||
if robot is not None:
|
||||
features = get_features_from_robot(robot)
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
@@ -522,35 +521,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
|
||||
|
||||
@property
|
||||
def features(self) -> list[str]:
|
||||
return list(self._features) + self.meta.video_keys
|
||||
def features(self) -> dict[str, dict]:
|
||||
return self.meta.features
|
||||
|
||||
@property
|
||||
def _features(self) -> datasets.Features:
|
||||
def hf_features(self) -> datasets.Features:
|
||||
"""Features of the hf_dataset."""
|
||||
if self.hf_dataset is not None:
|
||||
return self.hf_dataset.features
|
||||
elif self.episode_buffer is None:
|
||||
raise NotImplementedError(
|
||||
"Dataset features must be infered from an existing hf_dataset or episode_buffer."
|
||||
)
|
||||
|
||||
features = {}
|
||||
for key in self.episode_buffer:
|
||||
if key in ["episode_index", "frame_index", "index", "task_index"]:
|
||||
features[key] = datasets.Value(dtype="int64")
|
||||
elif key in ["next.done", "next.success"]:
|
||||
features[key] = datasets.Value(dtype="bool")
|
||||
elif key in ["timestamp", "next.reward"]:
|
||||
features[key] = datasets.Value(dtype="float32")
|
||||
elif key in self.meta.image_keys:
|
||||
features[key] = datasets.Image()
|
||||
elif key in self.meta.keys:
|
||||
features[key] = datasets.Sequence(
|
||||
length=self.meta.shapes[key], feature=datasets.Value(dtype="float32")
|
||||
)
|
||||
|
||||
return datasets.Features(features)
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
@@ -650,17 +630,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
# TODO(aliberts): Handle resume
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
return {
|
||||
"size": 0,
|
||||
"episode_index": self.meta.total_episodes if episode_index is None else episode_index,
|
||||
"task_index": None,
|
||||
"frame_index": [],
|
||||
"timestamp": [],
|
||||
**{key: [] for key in self.meta.features},
|
||||
**{key: [] for key in self.meta.image_keys},
|
||||
**{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
|
||||
}
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
write_image(image, fpath)
|
||||
else:
|
||||
self.image_writer.save_image(image=image, fpath=fpath)
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
"""
|
||||
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
||||
@@ -668,35 +657,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
then needs to be called.
|
||||
"""
|
||||
frame_index = self.episode_buffer["size"]
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(frame_index / self.fps)
|
||||
self.episode_buffer["next.done"].append(False)
|
||||
|
||||
# Save all observed modalities except images
|
||||
for key in self.meta.keys:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
for key, ft in self.features.items():
|
||||
if key == "frame_index":
|
||||
self.episode_buffer[key].append(frame_index)
|
||||
elif key == "timestamp":
|
||||
self.episode_buffer[key].append(frame_index / self.fps)
|
||||
elif key in frame and ft["dtype"] not in ["image", "video"]:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
elif key in frame and ft["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._save_image(frame[key], img_path)
|
||||
if ft["dtype"] == "image":
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
if self.image_writer is None:
|
||||
return
|
||||
|
||||
# Save images
|
||||
for cam_key in self.meta.camera_keys:
|
||||
img_path = self.image_writer.get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.image_writer.save_image(
|
||||
image=frame[cam_key],
|
||||
fpath=img_path,
|
||||
)
|
||||
|
||||
if cam_key in self.meta.image_keys:
|
||||
self.episode_buffer[cam_key].append(str(img_path))
|
||||
|
||||
def add_episode(self, task: str, encode_videos: bool = False) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
@@ -714,23 +693,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
raise NotImplementedError()
|
||||
|
||||
task_index = self.meta.get_task_index(task)
|
||||
self.episode_buffer["next.done"][-1] = True
|
||||
|
||||
for key in self.episode_buffer:
|
||||
if key in self.meta.image_keys:
|
||||
continue
|
||||
elif key in self.meta.keys:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
if not set(self.episode_buffer.keys()) == set(self.features):
|
||||
raise ValueError()
|
||||
|
||||
for key, ft in self.features.items():
|
||||
if key == "index":
|
||||
self.episode_buffer[key] = np.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
elif key == "episode_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
|
||||
self.episode_buffer[key] = np.full((episode_length,), episode_index)
|
||||
elif key == "task_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), task_index)
|
||||
else:
|
||||
self.episode_buffer[key] = np.full((episode_length,), task_index)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif ft["shape"][0] == 1:
|
||||
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
|
||||
elif ft["shape"][0] > 1:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self.episode_buffer["index"] = torch.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
self.meta.add_episode(episode_index, episode_length, task, task_index)
|
||||
|
||||
self._wait_image_writer()
|
||||
@@ -744,7 +728,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_index: int) -> None:
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.hf_features, split="train")
|
||||
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
write_parquet(ep_dataset, ep_data_path)
|
||||
@@ -753,7 +737,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if self.image_writer is not None:
|
||||
for cam_key in self.meta.camera_keys:
|
||||
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
|
||||
img_dir = self._get_image_file_path(
|
||||
episode_index=episode_index, image_key=cam_key, frame_index=0
|
||||
).parent
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
@@ -761,13 +747,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
|
||||
if isinstance(self.image_writer, ImageWriter):
|
||||
if isinstance(self.image_writer, AsyncImageWriter):
|
||||
logging.warning(
|
||||
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
|
||||
"You are starting a new AsyncImageWriter that is replacing an already exising one in the dataset."
|
||||
)
|
||||
|
||||
self.image_writer = ImageWriter(
|
||||
write_dir=self.root / "images",
|
||||
self.image_writer = AsyncImageWriter(
|
||||
num_processes=num_processes,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
@@ -787,19 +772,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
for episode_index in range(self.meta.total_episodes):
|
||||
for key in self.meta.video_keys:
|
||||
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
|
||||
# to call self.image_writer here
|
||||
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||
if video_path.is_file():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
|
||||
img_dir = self._get_image_file_path(
|
||||
episode_index=episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
@@ -810,8 +797,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.encode_videos()
|
||||
self.meta.write_video_info()
|
||||
|
||||
if not keep_image_files and self.image_writer is not None:
|
||||
shutil.rmtree(self.image_writer.write_dir)
|
||||
if not keep_image_files:
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
@@ -989,7 +978,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
|
||||
features.update(
|
||||
{k: v for k, v in dataset.hf_features.items() if k not in self.disabled_data_keys}
|
||||
)
|
||||
return features
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing import Any
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import pyarrow.compute as pc
|
||||
import torch
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, HfApi
|
||||
@@ -39,6 +40,7 @@ TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
@@ -222,6 +224,24 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
|
||||
return version
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
else:
|
||||
assert len(ft["shape"]) == 1
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
|
||||
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
camera_ft = {}
|
||||
if robot.cameras:
|
||||
@@ -270,6 +290,31 @@ def get_episode_data_index(
|
||||
}
|
||||
|
||||
|
||||
def calculate_total_episode(
|
||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
||||
return total_episodes
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = []
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
|
||||
Reference in New Issue
Block a user