forked from tangger/lerobot
Add add_episode & task logic
This commit is contained in:
@@ -19,16 +19,19 @@ from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.image_writer import ImageWriter
|
||||
from lerobot.common.datasets.utils import (
|
||||
append_jsonl,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
create_branch,
|
||||
create_empty_dataset_info,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
@@ -160,6 +163,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.image_writer = image_writer
|
||||
self.episode_buffer = {}
|
||||
self.consolidated = True
|
||||
self.delta_indices = None
|
||||
|
||||
# Load metadata
|
||||
@@ -192,6 +196,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# - [ ] Update episode_index (arg update=True)
|
||||
# - [ ] Update info.json (arg update=True)
|
||||
|
||||
def push_to_repo(self, push_videos: bool = True) -> None:
|
||||
if not self.consolidated:
|
||||
raise RuntimeError(
|
||||
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
|
||||
"Please use the '.consolidate()' method first."
|
||||
)
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
|
||||
upload_folder(
|
||||
repo_id=self.repo_id,
|
||||
folder_path=self.root,
|
||||
repo_type="dataset",
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
@@ -303,11 +325,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""Number of samples/frames in selected episodes."""
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes selected."""
|
||||
@@ -318,6 +335,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def total_chunks(self) -> int:
|
||||
"""Total number of chunks (groups of episodes)."""
|
||||
@@ -331,7 +358,46 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
self.info.get("shapes")
|
||||
return self.info["shapes"]
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
"""Shapes for the different features."""
|
||||
if self.hf_dataset is not None:
|
||||
return self.hf_dataset.features
|
||||
elif self.episode_buffer is None:
|
||||
raise NotImplementedError(
|
||||
"Dataset features must be infered from an existing hf_dataset or episode_buffer."
|
||||
)
|
||||
|
||||
features = {}
|
||||
for key in self.episode_buffer:
|
||||
if key in ["episode_index", "frame_index", "index", "task_index"]:
|
||||
features[key] = datasets.Value(dtype="int64")
|
||||
elif key in ["next.done", "next.success"]:
|
||||
features[key] = datasets.Value(dtype="bool")
|
||||
elif key in ["timestamp", "next.reward"]:
|
||||
features[key] = datasets.Value(dtype="float32")
|
||||
elif key in self.image_keys:
|
||||
features[key] = datasets.Image()
|
||||
elif key in self.keys:
|
||||
features[key] = datasets.Sequence(
|
||||
length=self.shapes[key], feature=datasets.Value(dtype="float32")
|
||||
)
|
||||
|
||||
return datasets.Features(features)
|
||||
|
||||
@property
|
||||
def task_to_task_index(self) -> dict:
|
||||
return {task: task_idx for task_idx, task in self.tasks.items()}
|
||||
|
||||
def get_task_index(self, task: str) -> int:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise creates a new task_index.
|
||||
"""
|
||||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
|
||||
def current_episode_index(self, idx: int) -> int:
|
||||
episode_index = self.hf_dataset["episode_index"][idx]
|
||||
@@ -447,12 +513,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f")"
|
||||
)
|
||||
|
||||
def _create_episode_buffer(self) -> dict:
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
# TODO(aliberts): Handle resume
|
||||
return {
|
||||
"chunk": self.total_chunks,
|
||||
"episode_index": self.total_episodes,
|
||||
"size": 0,
|
||||
"episode_index": self.total_episodes if episode_index is None else episode_index,
|
||||
"task_index": None,
|
||||
"frame_index": [],
|
||||
"timestamp": [],
|
||||
"next.done": [],
|
||||
@@ -490,6 +556,92 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
file_path=img_path,
|
||||
)
|
||||
|
||||
def add_episode(self, task: str, encode_videos: bool = False) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
|
||||
Use encode_videos if you want to encode videos during the saving of each episode. Otherwise,
|
||||
you can do it later during dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
"""
|
||||
episode_length = self.episode_buffer.pop("size")
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
task_index = self.get_task_index(task)
|
||||
self.episode_buffer["next.done"][-1] = True
|
||||
|
||||
for key in self.episode_buffer:
|
||||
if key in self.keys:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
elif key == "episode_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
|
||||
elif key == "task_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), task_index)
|
||||
else:
|
||||
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
|
||||
|
||||
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
|
||||
self._save_episode_table(episode_index)
|
||||
|
||||
if encode_videos:
|
||||
pass # TODO
|
||||
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_index: int) -> None:
|
||||
features = self.features
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
|
||||
ep_table = ep_dataset._data.table
|
||||
ep_data_path = self.get_data_file_path(ep_index=episode_index, return_str=False)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(ep_table, ep_data_path)
|
||||
|
||||
def _save_episode_to_metadata(
|
||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||
) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
if task_index not in self.tasks:
|
||||
self.info["total_tasks"] += 1
|
||||
self.tasks[task_index] = task
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
write_json(self.info, self.root / "meta/info.json")
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": [task],
|
||||
"length": episode_length,
|
||||
}
|
||||
append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
|
||||
|
||||
def delete_episode(self) -> None:
|
||||
pass # TODO
|
||||
|
||||
def consolidate(self) -> None:
|
||||
pass # TODO
|
||||
# Sanity checks:
|
||||
# - [ ] shapes
|
||||
# - [ ] ep_lenghts
|
||||
# - [ ] number of files
|
||||
# - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
|
||||
# - [ ] no remaining self.image_writer.dir
|
||||
self.consolidated = True
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -508,6 +660,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._version = CODEBASE_VERSION
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = image_writer
|
||||
obj.hf_dataset = None
|
||||
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warn(
|
||||
@@ -515,12 +668,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
|
||||
obj.tasks = {}
|
||||
obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
|
||||
write_json(obj.info, obj.root / "meta/info.json")
|
||||
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj._create_episode_buffer()
|
||||
|
||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
|
||||
# It is used to know when certain operations are need (for instance, computing dataset statistics).
|
||||
# In order to be able to push the dataset to the hub, it needs to be consolidation first.
|
||||
obj.consolidated = True
|
||||
|
||||
# obj.episodes = None
|
||||
# obj.image_transforms = None
|
||||
# obj.delta_timestamps = None
|
||||
|
||||
@@ -81,6 +81,11 @@ def write_json(data: dict, fpath: Path) -> None:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def append_jsonl(data: dict, fpath: Path) -> None:
|
||||
with jsonlines.open(fpath, "a") as writer:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
|
||||
Reference in New Issue
Block a user