diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index b339fe775..676a4a925 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -79,8 +79,8 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset class EpisodeSampler(torch.utils.data.Sampler): def __init__(self, dataset: LeRobotDataset, episode_index: int): - from_idx = dataset.meta.episodes["dataset_from_index"][episode_index].item() - to_idx = dataset.meta.episodes["dataset_to_index"][episode_index].item() + from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] + to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] self.frame_ids = range(from_idx, to_idx) def __iter__(self) -> Iterator: @@ -283,7 +283,7 @@ def main(): tolerance_s = kwargs.pop("tolerance_s") logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s) + dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s) visualize_dataset(dataset, **vars(args))