Fix datasets missing versions (#318)

This commit is contained in:
Simon Alibert
2024-07-16 23:02:31 +02:00
committed by GitHub
parent 5f5efe7cb9
commit 8865e19c12
12 changed files with 156 additions and 120 deletions

View File

@@ -35,15 +35,16 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/codebase_version.md
CODEBASE_VERSION = "v1.5"
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_id: str,
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
image_transforms: Callable | None = None,
@@ -52,7 +53,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
):
super().__init__()
self.repo_id = repo_id
self.version = version
self.root = root
self.split = split
self.image_transforms = image_transforms
@@ -60,16 +60,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
# load data from the hub, or from local disk when root is provided
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
if split == "train":
self.episode_data_index = load_episode_data_index(repo_id, version, root)
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
else:
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
self.hf_dataset = reset_episode_index(self.hf_dataset)
self.stats = load_stats(repo_id, version, root)
self.info = load_info(repo_id, version, root)
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
self.info = load_info(repo_id, CODEBASE_VERSION, root)
if self.video:
self.videos_dir = load_videos(repo_id, version, root)
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
self.video_backend = video_backend if video_backend is not None else "pyav"
@property
@@ -164,7 +164,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return (
f"{self.__class__.__name__}(\n"
f" Repository ID: '{self.repo_id}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
@@ -173,6 +172,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n"
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
f")"
)
@@ -180,7 +180,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
def from_preloaded(
cls,
repo_id: str = "from_preloaded",
version: str | None = CODEBASE_VERSION,
root: Path | None = None,
split: str = "train",
transform: callable = None,
@@ -204,7 +203,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# create an empty object of type LeRobotDataset
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.version = version
obj.root = root
obj.split = split
obj.image_transforms = transform
@@ -228,7 +226,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_ids: list[str],
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
image_transforms: Callable | None = None,
@@ -242,7 +239,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self._datasets = [
LeRobotDataset(
repo_id,
version=version,
root=root,
split=split,
delta_timestamps=delta_timestamps,
@@ -279,7 +275,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
)
self.disabled_data_keys.update(extra_keys)
self.version = version
self.root = root
self.split = split
self.image_transforms = image_transforms
@@ -395,7 +390,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"