Add file paths
This commit is contained in:
@@ -22,6 +22,7 @@ from typing import Callable
|
||||
import datasets
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
@@ -32,7 +33,7 @@ from lerobot.common.datasets.utils import (
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_hub_safe_version,
|
||||
load_hf_dataset,
|
||||
hf_transform_to_torch,
|
||||
load_metadata,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
|
||||
@@ -100,7 +101,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
│ ├── episodes.jsonl
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ └── tasks.json
|
||||
│ └── tasks.jsonl
|
||||
└── videos (optional)
|
||||
├── chunk-000
|
||||
│ ├── observation.images.laptop
|
||||
@@ -160,12 +161,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Load metadata
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
|
||||
self.download_metadata()
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
|
||||
|
||||
# Load actual data
|
||||
self.download_episodes()
|
||||
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||
|
||||
# Check timestamps
|
||||
@@ -187,13 +188,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# - [ ] Update episode_index (arg update=True)
|
||||
# - [ ] Update info.json (arg update=True)
|
||||
|
||||
def download_metadata(self) -> None:
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
) -> None:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self._version,
|
||||
local_dir=self.root,
|
||||
allow_patterns="meta/",
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
def download_episodes(self) -> None:
|
||||
@@ -207,26 +213,46 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
files = None
|
||||
ignore_patterns = None if self.download_videos else "videos/"
|
||||
if self.episodes is not None:
|
||||
files = [
|
||||
self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
|
||||
if len(self.video_keys) > 0 and self.download_videos:
|
||||
video_files = [
|
||||
self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
|
||||
self.get_video_file_path(ep_idx, vid_key)
|
||||
for vid_key in self.video_keys
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
files += video_files
|
||||
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self._version,
|
||||
local_dir=self.root,
|
||||
allow_patterns=files,
|
||||
ignore_patterns=ignore_patterns,
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
path = str(self.root / "data")
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
else:
|
||||
files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def get_data_file_path(self, ep_index: int, return_str: bool = True) -> str | Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.data_path.format(
|
||||
episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
|
||||
)
|
||||
return str(self.root / fpath) if return_str else self.root / fpath
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str, return_str: bool = True) -> str | Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
return str(self.root / fpath) if return_str else self.root / fpath
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
ep_chunk = ep_index // self.chunks_size
|
||||
if ep_index > 0 and ep_index % self.chunks_size == 0:
|
||||
ep_chunk -= 1
|
||||
return ep_chunk
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
@@ -355,7 +381,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
|
||||
video_path = self.root / self.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames_torchvision(
|
||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
@@ -436,6 +462,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.write_info()
|
||||
obj.fps = fps
|
||||
|
||||
if not all(cam.fps == fps for cam in robot.cameras):
|
||||
logging.warn(
|
||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
|
||||
# obj.episodes = None
|
||||
# obj.image_transforms = None
|
||||
# obj.delta_timestamps = None
|
||||
|
||||
Reference in New Issue
Block a user