Remove Prod, Tests are passind

This commit is contained in:
Cadene
2024-04-19 23:18:45 +00:00
parent 35a573c98e
commit c20cf2fbbc
12 changed files with 96 additions and 110 deletions

View File

@@ -51,8 +51,10 @@ print(f"{hf_dataset.features=}")
# display useful statistics about frames and episodes, which are sequences of frames from the same video
print(f"number of frames: {len(hf_dataset)=}")
print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
print(
f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
)
# select the frames belonging to episode number 5
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)

View File

@@ -63,8 +63,9 @@ dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_inde
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
frames = [sample["observation.image"] for sample in dataset]
# but frames are now channel first to follow pytorch convention,
# to view them, we convert to channel last
# but frames are now float32 range [0,1] channel first to follow pytorch convention,
# to view them, we convert to uint8 range [0,255] channel last
frames = [(frame * 255).type(torch.uint8) for frame in frames]
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# and finally save them to a mp4 video