From 0f0113a7a6b5a59ac92245296f3bbac89d6f7496 Mon Sep 17 00:00:00 2001 From: Cadene Date: Wed, 10 Apr 2024 16:03:39 +0000 Subject: [PATCH] print_cuda_memory_usage docstring --- README.md | 15 +++++++-------- lerobot/common/utils.py | 1 + 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 51e03d65..6473e1eb 100644 --- a/README.md +++ b/README.md @@ -120,27 +120,26 @@ wandb login You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities: ```python """ Copy pasted from `examples/1_visualize_dataset.py` """ +import os +from pathlib import Path + import lerobot from lerobot.common.datasets.aloha import AlohaDataset -from torchrl.data.replay_buffers import SamplerWithoutReplacement from lerobot.scripts.visualize_dataset import render_dataset print(lerobot.available_datasets) # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium'] -# we use this sampler to sample 1 frame after the other -sampler = SamplerWithoutReplacement(shuffle=False) - -dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler) +# TODO(rcadene): remove DATA_DIR +dataset = AlohaDataset("aloha_sim_transfer_cube_human", root=Path(os.environ.get("DATA_DIR"))) video_paths = render_dataset( dataset, out_dir="outputs/visualize_dataset/example", - max_num_samples=300, - fps=50, + max_num_episodes=1, ) print(video_paths) -# >>> ['outputs/visualize_dataset/example/episode_0.mp4'] +# ['outputs/visualize_dataset/example/episode_0_top.mp4'] ``` Or you can achieve the same result by executing our script from the command line: diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index e3e22832..373a3bbc 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -98,6 +98,7 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D def print_cuda_memory_usage(): + """Use this function to locate and debug memory leak.""" import gc gc.collect()