From c1b4dae6d048597709e82f5094b81a2be9472517 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Mon, 17 Feb 2025 17:23:07 +0100
Subject: [PATCH] Add streaming

---
Review notes (ignored when the patch is applied):
 * `html_root` is a plain method, so `_query_videos` now calls it; reading
   `self.meta.html_root` without parentheses yields a bound method.
 * The streamed video location is built as a string, because pathlib collapses
   the "//" of "https://". The local (non-streaming) path is still a Path.
   NOTE(review): assumes decode_video_frames_torchvision accepts a URL string
   - TODO confirm.
 * `item_to_torch` builds the ToTensor transform once per call instead of once
   per image; the set of converted values is unchanged.
 * NOTE(review): in streaming mode `__getitem__` ignores `idx` and returns the
   next item of a shuffled iterator - confirm this is intended for samplers.
 * NOTE(review): with streaming=True the download step is skipped, but
   `load_hf_dataset` still reads parquet files under `self.root` - confirm the
   files are expected to be present locally.
 * NOTE(review): `check_delta_timestamps` only receives delta_timestamps, fps
   and tolerance_s, so skipping it in streaming mode may be unnecessary.

 lerobot/common/datasets/lerobot_dataset.py | 43 +++++++++++++++++-----
 lerobot/common/datasets/utils.py           | 12 ++++++
 2 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 9483bf0a..0b94e4b9 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -50,6 +50,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     get_hub_safe_version,
     hf_transform_to_torch,
+    item_to_torch,
     load_episodes,
     load_info,
     load_stats,
@@ -214,6 +215,9 @@ class LeRobotDatasetMetadata:
         task_index = self.task_to_task_index.get(task, None)
         return task_index if task_index is not None else self.total_tasks
 
+    def html_root(self) -> str:
+        return f"https://huggingface.co/datasets/{self.repo_id}/resolve/main"
+
     def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
         self.info["total_episodes"] += 1
         self.info["total_frames"] += episode_length
@@ -334,6 +338,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         download_videos: bool = True,
         local_files_only: bool = False,
         video_backend: str | None = None,
+        streaming: bool = False,
     ):
         """
         2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -431,6 +436,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 will be made. Defaults to False.
             video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
                 a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
+            streaming (bool, optional): If set to True, don't download the data files. Instead, it streams the data
+                progressively while iterating on the dataset. Defaults to False.
         """
         super().__init__()
         self.repo_id = repo_id
@@ -440,10 +447,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episodes = episodes
         self.tolerance_s = tolerance_s
         self.video_backend = video_backend if video_backend else "pyav"
-        self.delta_indices = None
         self.local_files_only = local_files_only
+        self.streaming = streaming
 
         # Unused attributes
+        self.delta_indices = None
         self.image_writer = None
         self.episode_buffer = None
 
@@ -456,16 +464,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
         check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
 
         # Load actual data
-        self.download_episodes(download_videos)
+        if not self.streaming:
+            self.download_episodes(download_videos)
         self.hf_dataset = self.load_hf_dataset()
+        if self.streaming:
+            self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
 
         # Check timestamps
-        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+        if not self.streaming:
+            check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
 
         # Setup delta_indices
         if self.delta_timestamps is not None:
-            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
+            if not self.streaming:
+                check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
 
         # Available stats implies all videos have been encoded and dataset is iterable
@@ -550,13 +563,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
         if self.episodes is None:
             path = str(self.root / "data")
-            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
+            hf_dataset = load_dataset("parquet", data_dir=path, split="train", streaming=self.streaming)
         else:
             files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
-            hf_dataset = load_dataset("parquet", data_files=files, split="train")
+            hf_dataset = load_dataset("parquet", data_files=files, split="train", streaming=self.streaming)
 
-        # TODO(aliberts): hf_dataset.set_format("torch")
-        hf_dataset.set_transform(hf_transform_to_torch)
+        if not self.streaming:
+            # TODO(aliberts): hf_dataset.set_format("torch")
+            hf_dataset.set_transform(hf_transform_to_torch)
 
         return hf_dataset
 
@@ -632,7 +646,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         item = {}
         for vid_key, query_ts in query_timestamps.items():
-            video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
+            rel_path = self.meta.get_video_file_path(ep_idx, vid_key)
+            video_path = f"{self.meta.html_root()}/{rel_path}" if self.streaming else self.root / rel_path
             frames = decode_video_frames_torchvision(
                 video_path, query_ts, self.tolerance_s, self.video_backend
             )
@@ -649,7 +664,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return self.num_frames
 
     def __getitem__(self, idx) -> dict:
-        item = self.hf_dataset[idx]
+        if self.streaming:
+            try:
+                item = next(self.hf_dataset_iter)
+            except StopIteration:
+                self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
+                item = next(self.hf_dataset_iter)
+            item = item_to_torch(item)
+        else:
+            item = self.hf_dataset[idx]
         ep_idx = item["episode_index"].item()
 
         query_indices = None
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 612bac39..865b4229 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -205,6 +205,18 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict
 
 
+def item_to_torch(item: dict):
+    """Convert the PIL images and other non-str, non-None values of `item` to tensors, in place."""
+    to_tensor = transforms.ToTensor()
+    for key, value in item.items():
+        if isinstance(value, PILImage.Image):
+            item[key] = to_tensor(value)
+        elif value is not None and not isinstance(value, str):
+            # None and str values are kept as they are.
+            item[key] = torch.tensor(value)
+    return item
+
+
 def _get_major_minor(version: str) -> tuple[int]:
     split = version.strip("v").split(".")
     return int(split[0]), int(split[1])