Allow dataset creation without robot

This commit is contained in:
Simon Alibert
2024-10-24 00:13:21 +02:00
parent 0d77be90ee
commit 60865e8980
3 changed files with 66 additions and 23 deletions

View File

@@ -35,6 +35,7 @@ from lerobot.common.datasets.utils import (
INFO_PATH,
STATS_PATH,
TASKS_PATH,
_get_info_from_robot,
append_jsonl,
check_delta_timestamps,
check_timestamps_sync,
@@ -683,7 +684,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
self.image_writer = ImageWriter(
write_dir=self.root,
write_dir=self.root / "images",
num_processes=num_processes,
num_threads=num_threads,
)
@@ -734,6 +735,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
# TODO(aliberts)
# - [ ] add video info in info.json
# Sanity checks:
# - [ ] shapes
# - [ ] ep_lenghts
@@ -744,8 +746,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
cls,
repo_id: str,
fps: int,
robot: Robot,
root: Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
keys: list[str] | None = None,
image_keys: list[str] | None = None,
video_keys: list[str] = None,
shapes: dict | None = None,
names: dict | None = None,
tolerance_s: float = 1e-4,
image_writer_processes: int = 0,
image_writer_threads_per_camera: int = 0,
@@ -757,26 +765,41 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj.tolerance_s = tolerance_s
obj.image_writer = None
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
if robot is not None:
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
obj.start_image_writter(
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
)
elif (
robot_type is None
or keys is None
or image_keys is None
or video_keys is None
or shapes is None
or names is None
):
raise ValueError()
if len(video_keys) > 0 and not use_videos:
raise ValueError
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
)
write_json(obj.info, obj.root / INFO_PATH)
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()
obj.image_writer = None
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
obj.start_image_writter(
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
)
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are need (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling