From 0e98c6ee966d95393cec2928948793659eea4965 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 14 Mar 2025 18:53:42 +0300 Subject: [PATCH 1/4] Add torchcodec cpu (#798) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Remi Co-authored-by: Remi Co-authored-by: Simon Alibert Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- .github/workflows/test.yml | 2 +- benchmarks/video/run_video_benchmark.py | 2 +- lerobot/common/datasets/lerobot_dataset.py | 14 ++-- lerobot/common/datasets/video_utils.py | 98 ++++++++++++++++++++++ pyproject.toml | 1 + 5 files changed, 107 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3ef478877..d91c53646 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -126,7 +126,7 @@ jobs: # portaudio19-dev is needed to install pyaudio run: | sudo apt-get update && \ - sudo apt-get install -y libegl1-mesa-dev portaudio19-dev + sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev - name: Install uv and python uses: astral-sh/setup-uv@v5 diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index e90664872..c62578c46 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -67,7 +67,7 @@ def parse_int_or_none(value) -> int | None: def check_datasets_formats(repo_ids: list) -> None: for repo_id in repo_ids: dataset = LeRobotDataset(repo_id) - if dataset.video: + if len(dataset.meta.video_keys) > 0: raise ValueError( f"Use only image dataset for running this benchmark. 
Video dataset provided: {repo_id}" ) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 5414c76df..101e71f44 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.datasets.video_utils import ( VideoFrame, - decode_video_frames_torchvision, + decode_video_frames, encode_video_frames, get_video_info, ) @@ -462,8 +462,8 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos (bool, optional): Flag to download the videos. Note that when set to True but the video files are already present on local disk, they won't be downloaded again. Defaults to True. - video_backend (str | None, optional): Video backend to use for decoding videos. There is currently - a single option which is the pyav decoder used by Torchvision. Defaults to pyav. + video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec. + You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. 
""" super().__init__() self.repo_id = repo_id @@ -473,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episodes = episodes self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION - self.video_backend = video_backend if video_backend else "pyav" + self.video_backend = video_backend if video_backend else "torchcodec" self.delta_indices = None # Unused attributes @@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset): item = {} for vid_key, query_ts in query_timestamps.items(): video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames_torchvision( - video_path, query_ts, self.tolerance_s, self.video_backend - ) + frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend) item[vid_key] = frames.squeeze(0) return item @@ -1029,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_timestamps = None obj.delta_indices = None obj.episode_data_index = None - obj.video_backend = video_backend if video_backend is not None else "pyav" + obj.video_backend = video_backend if video_backend is not None else "torchcodec" return obj diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 9f043f966..3fe19d8b6 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -27,6 +27,35 @@ import torch import torchvision from datasets.features.features import register_feature from PIL import Image +from torchcodec.decoders import VideoDecoder + + +def decode_video_frames( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + backend: str = "torchcodec", +) -> torch.Tensor: + """ + Decodes video frames using the specified backend. + + Args: + video_path (Path): Path to the video file. + timestamps (list[float]): List of timestamps to extract frames. + tolerance_s (float): Allowed deviation in seconds for frame retrieval. 
+ backend (str, optional): Backend to use for decoding. Defaults to "torchcodec". + + Returns: + torch.Tensor: Decoded frames. + + Currently supports torchcodec on cpu and pyav. + """ + if backend == "torchcodec": + return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) + elif backend in ["pyav", "video_reader"]: + return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + else: + raise ValueError(f"Unsupported video backend: {backend}") def decode_video_frames_torchvision( @@ -127,6 +156,75 @@ def decode_video_frames_torchvision( return closest_frames +def decode_video_frames_torchcodec( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + device: str = "cpu", + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + """Loads frames associated with the requested timestamps of a video using torchcodec. + + Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors. + + Note: Video benefits from inter-frame compression. Instead of storing every frame individually, + the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to + that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame, + and all subsequent frames until reaching the requested frame. The number of key frames in a video + can be adjusted during encoding to take into account decoding time and video size in bytes. 
+ """ + # initialize video decoder + decoder = VideoDecoder(video_path, device=device, seek_mode="approximate") + loaded_frames = [] + loaded_ts = [] + # get metadata for frame information + metadata = decoder.metadata + average_fps = metadata.average_fps + + # convert timestamps to frame indices + frame_indices = [round(ts * average_fps) for ts in timestamps] + + # retrieve frames based on indices + frames_batch = decoder.get_frames_at(indices=frame_indices) + + for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False): + loaded_frames.append(frame) + loaded_ts.append(pts.item()) + if log_loaded_timestamps: + logging.info(f"Frame loaded at timestamp={pts:.4f}") + + query_ts = torch.tensor(timestamps) + loaded_ts = torch.tensor(loaded_ts) + + # compute distances between each query timestamp and loaded timestamps + dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1) + min_, argmin_ = dist.min(1) + + is_within_tol = min_ < tolerance_s + assert is_within_tol.all(), ( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + "It means that the closest frame that can be loaded from the video is too far away in time." + "This might be due to synchronization issues with timestamps during data collection." + "To be safe, we advise to ignore this item during training." 
+ f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + ) + + # get closest frames to the query timestamps + closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) + closest_ts = loaded_ts[argmin_] + + if log_loaded_timestamps: + logging.info(f"{closest_ts=}") + + # convert to float32 in [0,1] range (channel first) + closest_frames = closest_frames.type(torch.float32) / 255 + + assert len(timestamps) == len(closest_frames) + return closest_frames + + def encode_video_frames( imgs_dir: Path | str, video_path: Path | str, diff --git a/pyproject.toml b/pyproject.toml index 19a5cffa7..f1f836b4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ dependencies = [ "rerun-sdk>=0.21.0", "termcolor>=2.4.0", "torch>=2.2.1", + "torchcodec>=0.2.1", "torchvision>=0.21.0", "wandb>=0.16.3", "zarr>=2.17.0", From 7dc9ffe4c9625b4eff4bdfa5eecb17ea47b2d9fe Mon Sep 17 00:00:00 2001 From: Huan Liu Date: Sat, 15 Mar 2025 00:07:14 +0800 Subject: [PATCH 2/4] Update 10_use_so100.md (#840) --- examples/10_use_so100.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/10_use_so100.md b/examples/10_use_so100.md index b8b45aa5e..d24232299 100644 --- a/examples/10_use_so100.md +++ b/examples/10_use_so100.md @@ -583,6 +583,13 @@ Let's explain it: Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`. +To resume training from a checkpoint, here is an example command that resumes from the `last` checkpoint of the `act_so100_test` policy: +```bash +python lerobot/scripts/train.py \ + --config_path=outputs/train/act_so100_test/checkpoints/last/pretrained_model/train_config.json \ + --resume=true +``` + ## K. Evaluate your policy You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. 
For instance, run this command to record 10 evaluation episodes: From a3cd18eda97dbdea58dcc646a728f020101b6459 Mon Sep 17 00:00:00 2001 From: Huan Liu Date: Sat, 15 Mar 2025 16:40:39 +0800 Subject: [PATCH 3/4] =?UTF-8?q?added=20wandb.run=5Fid=20to=20allow=20resum?= =?UTF-8?q?ing=20without=20wandb=20log;=20updated=20log=20m=E2=80=A6=20(#8?= =?UTF-8?q?41)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- lerobot/common/utils/wandb_utils.py | 8 +++++++- lerobot/configs/default.py | 1 + lerobot/configs/train.py | 4 +++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py index 9985b894c..700ebea5d 100644 --- a/lerobot/common/utils/wandb_utils.py +++ b/lerobot/common/utils/wandb_utils.py @@ -69,7 +69,13 @@ class WandBLogger: os.environ["WANDB_SILENT"] = "True" import wandb - wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None + wandb_run_id = ( + cfg.wandb.run_id + if cfg.wandb.run_id + else get_wandb_run_id_from_filesystem(self.log_dir) + if cfg.resume + else None + ) wandb.init( id=wandb_run_id, project=self.cfg.project, diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py index 1e7f5819a..dee0649aa 100644 --- a/lerobot/configs/default.py +++ b/lerobot/configs/default.py @@ -46,6 +46,7 @@ class WandBConfig: project: str = "lerobot" entity: str | None = None notes: str | None = None + run_id: str | None = None @dataclass diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index 2b147a5b3..7a787b83e 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -79,7 +79,9 @@ class TrainPipelineConfig(HubMixin): # The entire train config is already loaded, we just need to get the checkpoint dir config_path = parser.parse_arg("config_path") if not config_path: - raise ValueError("A config_path is 
expected when resuming a run.") + raise ValueError( + f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}" + ) if not Path(config_path).resolve().exists(): raise NotADirectoryError( f"{config_path=} is expected to be a local path. " From 9f0a8a49d0497d03b7d9c2cc840f476e6de8df99 Mon Sep 17 00:00:00 2001 From: Guillaume LEGENDRE Date: Sat, 15 Mar 2025 11:34:17 +0100 Subject: [PATCH 4/4] Update test-docker-build.yml --- .github/workflows/test-docker-build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-docker-build.yml b/.github/workflows/test-docker-build.yml index e77c570ea..c31025645 100644 --- a/.github/workflows/test-docker-build.yml +++ b/.github/workflows/test-docker-build.yml @@ -41,7 +41,7 @@ jobs: - name: Get changed files id: changed-files - uses: tj-actions/changed-files@v44 + uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42 with: files: docker/** json: "true"