Loads episode_data_index and stats during dataset __init__ (#85)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -18,7 +18,10 @@ dataset = PushtDataset()
|
||||
```
|
||||
by one of these:
|
||||
```python
|
||||
dataset = XarmDataset()
|
||||
dataset = XarmDataset("xarm_lift_medium")
|
||||
dataset = XarmDataset("xarm_lift_medium_replay")
|
||||
dataset = XarmDataset("xarm_push_medium")
|
||||
dataset = XarmDataset("xarm_push_medium_replay")
|
||||
dataset = AlohaDataset("aloha_sim_insertion_human")
|
||||
dataset = AlohaDataset("aloha_sim_insertion_scripted")
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
|
||||
@@ -44,6 +47,7 @@ from lerobot.common.datasets.pusht import PushtDataset
|
||||
dataset = PushtDataset()
|
||||
|
||||
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{dataset=}")
|
||||
print(f"{dataset.hf_dataset=}")
|
||||
|
||||
@@ -55,13 +59,16 @@ print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||
|
||||
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
|
||||
# TODO(rcadene): remove this example of accessing hf_dataset
|
||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||
|
||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
|
||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
|
||||
frames = [sample["observation.image"] for sample in dataset]
|
||||
|
||||
# but frames are now channel first to follow pytorch convention,
|
||||
# to view them, we convert to channel last
|
||||
# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
|
||||
# to view them, we convert to uint8 range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c)
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
|
||||
# and finally save them to a mp4 video
|
||||
|
||||
Reference in New Issue
Block a user