Rename num_samples -> num_frames for consistency
This commit is contained in:
@@ -357,8 +357,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return self.info["names"]
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
"""Number of samples/frames in selected episodes."""
|
||||
def num_frames(self) -> int:
|
||||
"""Number of frames in selected episodes."""
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
|
||||
|
||||
@property
|
||||
@@ -510,7 +510,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx) -> dict:
|
||||
item = self.hf_dataset[idx]
|
||||
@@ -544,7 +544,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Selected episodes: {self.episodes},\n"
|
||||
f" Number of selected episodes: {self.num_episodes},\n"
|
||||
f" Number of selected samples: {self.num_samples},\n"
|
||||
f" Number of selected samples: {self.num_frames},\n"
|
||||
f"\n{json.dumps(self.info, indent=4)}\n"
|
||||
)
|
||||
|
||||
@@ -981,9 +981,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
return video_frame_keys
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
def num_frames(self) -> int:
|
||||
"""Number of samples/frames."""
|
||||
return sum(d.num_samples for d in self._datasets)
|
||||
return sum(d.num_frames for d in self._datasets)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
@@ -1000,7 +1000,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
return 1 / self.fps - 1e-4
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
@@ -1009,8 +1009,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_samples:
|
||||
start_idx += dataset.num_samples
|
||||
if idx >= start_idx + dataset.num_frames:
|
||||
start_idx += dataset.num_frames
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
@@ -1028,7 +1028,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository IDs: '{self.repo_ids}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Samples: {self.num_frames},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
|
||||
Reference in New Issue
Block a user