Add download_metadata, move default paths
This commit is contained in:
@@ -31,9 +31,7 @@ from lerobot.common.datasets.utils import (
|
||||
get_episode_data_index,
|
||||
get_hub_safe_version,
|
||||
load_hf_dataset,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
load_metadata,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
|
||||
|
||||
@@ -41,6 +39,12 @@ from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_
|
||||
CODEBASE_VERSION = "v2.0"
|
||||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = (
|
||||
"data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
|
||||
)
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
@@ -70,7 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Instantiating this class with this 'repo_id' will download the dataset from that address and load
|
||||
it, provided your dataset is compliant with codebase_version v2.0. If your dataset has been created
|
||||
before this new format, you will be prompted to convert it using our conversion script from v1.6
|
||||
to v2.0, which you can find at [TODO(aliberts): move conversion script & add location here].
|
||||
to v2.0, which you can find at lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
|
||||
|
||||
2. Your dataset already exists on your local disk in the 'root' folder:
|
||||
This is typically the case when you recorded your dataset locally and you may or may not have
|
||||
@@ -139,7 +143,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
|
||||
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
||||
multiples of 1/fps. Defaults to 1e-4.
|
||||
download_videos (bool, optional): Flag to download the videos. Defaults to True.
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
"""
|
||||
@@ -157,9 +163,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Load metadata
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
|
||||
self.info = load_info(repo_id, self._version, self.root)
|
||||
self.stats = load_stats(repo_id, self._version, self.root)
|
||||
self.tasks = load_tasks(repo_id, self._version, self.root)
|
||||
self.download_metadata()
|
||||
self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
|
||||
|
||||
# Load actual data
|
||||
self.download_episodes()
|
||||
@@ -185,6 +190,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# - [ ] Update episode_index (arg update=True)
|
||||
# - [ ] Update info.json (arg update=True)
|
||||
|
||||
def download_metadata(self) -> None:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self._version,
|
||||
local_dir=self.root,
|
||||
allow_patterns="meta/",
|
||||
)
|
||||
|
||||
def download_episodes(self) -> None:
|
||||
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||
@@ -227,11 +241,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
|
||||
|
||||
@property
|
||||
def episode_dicts(self) -> list[dict]:
|
||||
"""List of dictionaries containing information for each episode, indexed by episode_index."""
|
||||
return self.info["episodes"]
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
@@ -254,7 +263,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return self.image_keys + self.video_keys
|
||||
|
||||
@property
|
||||
@@ -277,6 +286,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_chunks(self) -> int:
|
||||
"""Total number of chunks (groups of episodes)."""
|
||||
return self.info["total_chunks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of episodes per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
@@ -397,42 +416,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_preloaded(
|
||||
def create(
|
||||
cls,
|
||||
repo_id: str = "from_preloaded",
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
# additional preloaded attributes
|
||||
hf_dataset=None,
|
||||
episode_data_index=None,
|
||||
stats=None,
|
||||
info=None,
|
||||
videos_dir=None,
|
||||
video_backend=None,
|
||||
tolerance_s: float = 1e-4,
|
||||
video_backend: str | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
|
||||
|
||||
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
|
||||
on the filesystem or uploading to the hub.
|
||||
|
||||
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc. are optional and potentially
|
||||
meaningless depending on the downstream usage of the return dataset.
|
||||
"""
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
# create an empty object of type LeRobotDataset
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = root
|
||||
obj.split = split
|
||||
obj.image_transforms = transform
|
||||
obj.delta_timestamps = delta_timestamps
|
||||
obj.hf_dataset = hf_dataset
|
||||
obj.episode_data_index = episode_data_index
|
||||
obj.stats = stats
|
||||
obj.info = info if info is not None else {}
|
||||
obj.videos_dir = videos_dir
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
# obj.episodes = None
|
||||
# obj.image_transforms = None
|
||||
# obj.delta_timestamps = None
|
||||
# obj.episode_data_index = episode_data_index
|
||||
# obj.stats = stats
|
||||
# obj.info = info if info is not None else {}
|
||||
# obj.videos_dir = videos_dir
|
||||
# obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
return obj
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user