Rename num_samples -> num_frames for consistency

This commit is contained in:
Simon Alibert
2024-11-01 19:47:16 +01:00
parent 2650872b76
commit 79d114cc1f
8 changed files with 28 additions and 28 deletions

View File

@@ -180,13 +180,13 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
# NOTE: the brackets around (d.num_frames / total_samples) are needed to minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
d.stats[data_key]["mean"] * (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.stats
)
@@ -195,12 +195,12 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
# NOTE: the brackets around (d.num_frames / total_samples) are needed to minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
* (d.num_samples / total_samples)
* (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.stats
)

View File

@@ -357,8 +357,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
return self.info["names"]
@property
def num_samples(self) -> int:
"""Number of samples/frames in selected episodes."""
def num_frames(self) -> int:
"""Number of frames in selected episodes."""
return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
@property
@@ -510,7 +510,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item
def __len__(self):
return self.num_samples
return self.num_frames
def __getitem__(self, idx) -> dict:
item = self.hf_dataset[idx]
@@ -544,7 +544,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Repository ID: '{self.repo_id}',\n"
f" Selected episodes: {self.episodes},\n"
f" Number of selected episodes: {self.num_episodes},\n"
f" Number of selected samples: {self.num_samples},\n"
f" Number of selected samples: {self.num_frames},\n"
f"\n{json.dumps(self.info, indent=4)}\n"
)
@@ -981,9 +981,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
return video_frame_keys
@property
def num_samples(self) -> int:
def num_frames(self) -> int:
"""Number of samples/frames."""
return sum(d.num_samples for d in self._datasets)
return sum(d.num_frames for d in self._datasets)
@property
def num_episodes(self) -> int:
@@ -1000,7 +1000,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_samples
return self.num_frames
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
@@ -1009,8 +1009,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_samples:
start_idx += dataset.num_samples
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
@@ -1028,7 +1028,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Samples: {self.num_frames},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"

View File

@@ -187,7 +187,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
# Shift the incoming indices if necessary.
if self.num_samples > 0:
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
@@ -227,11 +227,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
)
@property
def num_samples(self) -> int:
def num_frames(self) -> int:
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
def __len__(self):
return self.num_samples
return self.num_frames
def _item_to_tensors(self, item: dict) -> dict:
item_ = {}