Hot fix to compute validation loss example test (#200)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-05-20 18:30:11 +02:00
committed by GitHub
parent 9b62c25f6c
commit c4da689171
3 changed files with 50 additions and 6 deletions

View File

@@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
on the target environment, whether that be in simulation or the real world.
"""
import math
from pathlib import Path
import torch
@@ -39,11 +40,29 @@ delta_timestamps = {
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
# Load the last 10 episodes of the dataset as a validation set.
# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
# For more information on the Slice API, please see:
# Load the last 10% of episodes of the dataset as a validation set.
# - Load full dataset
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
# - Calculate train and val subsets
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
num_val_episodes = full_dataset.num_episodes - num_train_episodes
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
# - Get first frame index of the validation set
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
# - Load frames subset belonging to validation set using the `split` argument.
# It utilizes the `datasets` library's syntax for slicing datasets.
# For more information on the Slice API, please see:
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
train_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
# Create dataloader for evaluation.
val_dataloader = torch.utils.data.DataLoader(