id -> index, finish moving compute_stats before hf_dataset push_to_hub

This commit is contained in:
Cadene
2024-04-19 10:33:42 +00:00
parent 64b09ea7a7
commit 714a776277
9 changed files with 120 additions and 99 deletions

View File

@@ -49,7 +49,7 @@ print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
# select the frames belonging to episode number 5
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# load all frames of episode 5 in RAM in PIL format
frames = hf_dataset["observation.image"]

View File

@@ -55,7 +55,7 @@ print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
frames = [sample["observation.image"] for sample in dataset]