Adds split_by_episodes to LeRobotDataset (#158)
This commit is contained in:
@@ -20,12 +20,14 @@ import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_info,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
load_videos,
|
||||
reset_episode_index,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
||||
|
||||
@@ -54,7 +56,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||
if split == "train":
|
||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||
else:
|
||||
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||
self.stats = load_stats(repo_id, version, root)
|
||||
self.info = load_info(repo_id, version, root)
|
||||
if self.video:
|
||||
|
||||
Reference in New Issue
Block a user