From 9b62c25f6c52e1b96de003ac64649958b660f893 Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Mon, 20 May 2024 22:04:04 +1000 Subject: [PATCH] Adds split_by_episodes to LeRobotDataset (#158) --- examples/4_calculate_validation_loss.py | 71 ++++++++++++++++++++++ lerobot/common/datasets/lerobot_dataset.py | 8 ++- lerobot/common/datasets/utils.py | 71 ++++++++++++++++++++++ tests/test_examples.py | 61 +++++++++++++++---- tests/test_utils.py | 52 +++++++++++++--- 5 files changed, 242 insertions(+), 21 deletions(-) create mode 100644 examples/4_calculate_validation_loss.py diff --git a/examples/4_calculate_validation_loss.py b/examples/4_calculate_validation_loss.py new file mode 100644 index 000000000..285184d25 --- /dev/null +++ b/examples/4_calculate_validation_loss.py @@ -0,0 +1,71 @@ +"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data. + +This technique can be useful for debugging and testing purposes, as well as identifying whether a policy +is learning effectively. + +Furthermore, relying on validation loss to evaluate performance is generally not considered a good practice, +especially in the context of imitation learning. The most reliable approach is to evaluate the policy directly +on the target environment, whether that be in simulation or the real world. +""" + +from pathlib import Path + +import torch +from huggingface_hub import snapshot_download + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy + +device = torch.device("cuda") + +# Download the diffusion policy for pusht environment +pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht")) +# OR uncomment the following to evaluate a policy from the local outputs/train folder. +# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") + +policy = DiffusionPolicy.from_pretrained(pretrained_policy_path) +policy.eval() +policy.to(device) + +# Set up the dataset. +delta_timestamps = { + # Load the previous image and state at -0.1 seconds before current frame, + # then load current image and state corresponding to 0.0 second. + "observation.image": [-0.1, 0.0], + "observation.state": [-0.1, 0.0], + # Load the previous action (-0.1), the next action to be executed (0.0), + # and 14 future actions with a 0.1 seconds spacing. All these actions will be + # used to calculate the loss. + "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], +} + +# Load the last 10 episodes of the dataset as a validation set. +# The `split` argument utilizes the `datasets` library's syntax for slicing datasets. +# For more information on the Slice API, please see: +# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits +val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps) + +# Create dataloader for evaluation. +val_dataloader = torch.utils.data.DataLoader( + val_dataset, + num_workers=4, + batch_size=64, + shuffle=False, + pin_memory=device != torch.device("cpu"), + drop_last=False, +) + +# Run validation loop. +loss_cumsum = 0 +n_examples_evaluated = 0 +for batch in val_dataloader: + batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} + output_dict = policy.forward(batch) + + loss_cumsum += output_dict["loss"].item() + n_examples_evaluated += batch["index"].shape[0] + +# Calculate the average loss over the validation set. +average_loss = loss_cumsum / n_examples_evaluated + +print(f"Average loss on validation set: {average_loss:.4f}") diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 6a3204f4d..057e47702 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -20,12 +20,14 @@ import datasets import torch from lerobot.common.datasets.utils import ( + calculate_episode_data_index, load_episode_data_index, load_hf_dataset, load_info, load_previous_and_future_frames, load_stats, load_videos, + reset_episode_index, ) from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos @@ -54,7 +56,11 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads self.hf_dataset = load_hf_dataset(repo_id, version, root, split) - self.episode_data_index = load_episode_data_index(repo_id, version, root) + if split == "train": + self.episode_data_index = load_episode_data_index(repo_id, version, root) + else: + self.episode_data_index = calculate_episode_data_index(self.hf_dataset) + self.hf_dataset = reset_episode_index(self.hf_dataset) self.stats = load_stats(repo_id, version, root) self.info = load_info(repo_id, version, root) if self.video: diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 5cdd5f7c0..207ccf7c1 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -15,6 +15,7 @@ # limitations under the License. import json from pathlib import Path +from typing import Dict import datasets import torch @@ -245,6 +246,76 @@ def load_previous_and_future_frames( return item +def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]: + """ + Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. + + Parameters: + - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index. + + Returns: + - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys: + - "from": A tensor containing the starting index of each episode. + - "to": A tensor containing the ending index of each episode. + """ + episode_data_index = {"from": [], "to": []} + + current_episode = None + """ + The episode_index is a list of integers, each representing the episode index of the corresponding example. + For instance, the following is a valid episode_index: + [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2] + + Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and + ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this: + { + "from": [0, 3, 7], + "to": [3, 7, 12] + } + """ + for idx, episode_idx in enumerate(hf_dataset["episode_index"]): + if episode_idx != current_episode: + # We encountered a new episode, so we append its starting location to the "from" list + episode_data_index["from"].append(idx) + # If this is not the first episode, we append the ending location of the previous episode to the "to" list + if current_episode is not None: + episode_data_index["to"].append(idx) + # Let's keep track of the current episode index + current_episode = episode_idx + else: + # We are still in the same episode, so there is nothing for us to do here + pass + # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list + episode_data_index["to"].append(idx + 1) + + for k in ["from", "to"]: + episode_data_index[k] = torch.tensor(episode_data_index[k]) + + return episode_data_index + + +def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: + """ + Reset the `episode_index` of the provided HuggingFace Dataset. + + `episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the + `episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0. + + This brings the `episode_index` to the required format. + """ + unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist() + episode_idx_to_reset_idx_mapping = { + ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs) + } + + def modify_ep_idx_func(example): + example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()] + return example + + hf_dataset = hf_dataset.map(modify_ep_idx_func) + return hf_dataset + + def cycle(iterable): """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. diff --git a/tests/test_examples.py b/tests/test_examples.py index de95a9915..9881e3fa8 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # TODO(aliberts): Mute logging for these tests +import io import subprocess import sys from pathlib import Path @@ -32,6 +33,11 @@ def _run_script(path): subprocess.run([sys.executable, path], check=True) +def _read_file(path): + with open(path) as file: + return file.read() + + def test_example_1(): path = "examples/1_load_lerobot_dataset.py" _run_script(path) @@ -39,18 +45,17 @@ def test_example_1(): @require_package("gym_pusht") -def test_examples_3_and_2(): +def test_examples_2_through_4(): """ Train a model with example 3, check the outputs. Evaluate the trained model with example 2, check the outputs. + Calculate the validation loss with example 4, check the outputs. """ - path = "examples/3_train_policy.py" + ### Test example 3 + file_contents = _read_file("examples/3_train_policy.py") - with open(path) as file: - file_contents = file.read() - - # Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers. + # Do fewer steps, use smaller batch, use CPU, and don't complicate things with dataloader workers. file_contents = _find_and_replace( file_contents, [ @@ -67,16 +72,17 @@ def test_examples_3_and_2(): for file_name in ["model.safetensors", "config.json"]: assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() - path = "examples/2_evaluate_pretrained_policy.py" + ### Test example 2 + file_contents = _read_file("examples/2_evaluate_pretrained_policy.py") - with open(path) as file: - file_contents = file.read() - - # Do less evals, use CPU, and use the local model. + # Do fewer evals, use CPU, and use the local model. file_contents = _find_and_replace( file_contents, [ - ('pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', ""), + ( + 'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', + "", + ), ( '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', @@ -89,3 +95,34 @@ def test_examples_3_and_2(): exec(file_contents, {}) assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists() + + ## Test example 4 + file_contents = _read_file("examples/4_calculate_validation_loss.py") + + # Run on a single example from the last episode, use CPU, and use the local model. + file_contents = _find_and_replace( + file_contents, + [ + ( + 'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', + "", + ), + ( + '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', + 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', + ), + ('split="train[24342:]"', 'split="train[-1:]"'), + ("num_workers=4", "num_workers=0"), + ('device = torch.device("cuda")', 'device = torch.device("cpu")'), + ("batch_size=64", "batch_size=1"), + ], + ) + + # Capture the output of the script + output_buffer = io.StringIO() + sys.stdout = output_buffer + exec(file_contents, {}) + printed_output = output_buffer.getvalue() + # Restore stdout to its original state + sys.stdout = sys.__stdout__ + assert "Average loss on validation set" in printed_output diff --git a/tests/test_utils.py b/tests/test_utils.py index bcdd95b4e..a7f770fb8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,20 +4,28 @@ from typing import Callable import numpy as np import pytest import torch +from datasets import Dataset +from lerobot.common.datasets.utils import ( + calculate_episode_data_index, + hf_transform_to_torch, + reset_episode_index, +) from lerobot.common.utils.utils import seeded_context, set_global_seed @pytest.mark.parametrize( "rand_fn", - [ - random.random, - np.random.random, - lambda: torch.rand(1).item(), - ] - + [lambda: torch.rand(1, device="cuda")] - if torch.cuda.is_available() - else [], + ( + [ + random.random, + np.random.random, + lambda: torch.rand(1).item(), + ] + + [lambda: torch.rand(1, device="cuda")] + if torch.cuda.is_available() + else [] + ), ) def test_seeding(rand_fn: Callable[[], int]): set_global_seed(0) @@ -36,3 +44,31 @@ def test_seeding(rand_fn: Callable[[], int]): c_ = rand_fn() # Check that `seeded_context` and `global_seed` give the same reproducibility. assert c_ == c + + +def test_calculate_episode_data_index(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) + assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) + + +def test_reset_episode_index(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [10, 10, 11, 12, 12, 12], + }, + ) + dataset.set_transform(hf_transform_to_torch) + correct_episode_index = [0, 0, 1, 2, 2, 2] + dataset = reset_episode_index(dataset) + assert dataset["episode_index"] == correct_episode_index