From f52f4f2cd2975686f8f8037d8396544712231475 Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Wed, 15 May 2024 12:13:09 +0200 Subject: [PATCH 1/4] Add copyrights (#157) --- lerobot/__init__.py | 15 +++++++++++++++ lerobot/__version__.py | 15 +++++++++++++++ .../_video_benchmark/run_video_benchmark.py | 15 +++++++++++++++ lerobot/common/datasets/factory.py | 15 +++++++++++++++ lerobot/common/datasets/lerobot_dataset.py | 15 +++++++++++++++ .../_diffusion_policy_replay_buffer.py | 15 +++++++++++++++ .../push_dataset_to_hub/_download_raw.py | 15 +++++++++++++++ .../_umi_imagecodecs_numcodecs.py | 15 +++++++++++++++ .../push_dataset_to_hub/aloha_hdf5_format.py | 15 +++++++++++++++ .../push_dataset_to_hub/compute_stats.py | 15 +++++++++++++++ .../push_dataset_to_hub/pusht_zarr_format.py | 15 +++++++++++++++ .../push_dataset_to_hub/umi_zarr_format.py | 15 +++++++++++++++ .../common/datasets/push_dataset_to_hub/utils.py | 15 +++++++++++++++ .../push_dataset_to_hub/xarm_pkl_format.py | 15 +++++++++++++++ lerobot/common/datasets/utils.py | 15 +++++++++++++++ lerobot/common/datasets/video_utils.py | 15 +++++++++++++++ lerobot/common/envs/factory.py | 15 +++++++++++++++ lerobot/common/envs/utils.py | 15 +++++++++++++++ lerobot/common/logger.py | 15 +++++++++++++++ lerobot/common/policies/act/configuration_act.py | 15 +++++++++++++++ lerobot/common/policies/act/modeling_act.py | 15 +++++++++++++++ .../diffusion/configuration_diffusion.py | 16 ++++++++++++++++ .../policies/diffusion/modeling_diffusion.py | 16 ++++++++++++++++ lerobot/common/policies/factory.py | 15 +++++++++++++++ lerobot/common/policies/normalize.py | 15 +++++++++++++++ lerobot/common/policies/policy_protocol.py | 15 +++++++++++++++ .../common/policies/tdmpc/configuration_tdmpc.py | 16 ++++++++++++++++ lerobot/common/policies/tdmpc/modeling_tdmpc.py | 16 ++++++++++++++++ lerobot/common/policies/utils.py | 15 +++++++++++++++ lerobot/common/utils/import_utils.py | 15 +++++++++++++++ lerobot/common/utils/io_utils.py | 15 +++++++++++++++ lerobot/common/utils/utils.py | 15 +++++++++++++++ lerobot/scripts/display_sys_info.py | 15 +++++++++++++++ lerobot/scripts/eval.py | 15 +++++++++++++++ lerobot/scripts/push_dataset_to_hub.py | 15 +++++++++++++++ lerobot/scripts/train.py | 15 +++++++++++++++ lerobot/scripts/visualize_dataset.py | 15 +++++++++++++++ tests/conftest.py | 15 +++++++++++++++ tests/scripts/save_dataset_to_safetensors.py | 15 +++++++++++++++ tests/scripts/save_policy_to_safetensor.py | 15 +++++++++++++++ tests/test_available.py | 15 +++++++++++++++ tests/test_datasets.py | 15 +++++++++++++++ tests/test_envs.py | 15 +++++++++++++++ tests/test_examples.py | 15 +++++++++++++++ tests/test_policies.py | 15 +++++++++++++++ tests/test_visualize_dataset.py | 15 +++++++++++++++ tests/utils.py | 15 +++++++++++++++ 47 files changed, 709 insertions(+) diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 072f4bc73..e188bc525 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library. We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables. diff --git a/lerobot/__version__.py b/lerobot/__version__.py index 6232b699d..d12aafaa9 100644 --- a/lerobot/__version__.py +++ b/lerobot/__version__.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """To enable `lerobot.__version__`""" from importlib.metadata import PackageNotFoundError, version diff --git a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py index 85d48fcfd..8be251dc1 100644 --- a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py +++ b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import random import shutil diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 22dd1789f..78967db6a 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import torch diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index f7bc5bd2b..21d098793 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os from pathlib import Path diff --git a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py index 2f5326508..33b4c9745 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/) Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script. diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py index d26f3d236..232fd0558 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ This file contains all obsolete download scripts. They are centralized here to not have to load useless dependencies when using datasets. diff --git a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py index 1561fb886..a118b7e78 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # imagecodecs/numcodecs.py # Copyright (c) 2021-2022, Christoph Gohlke diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py index f51a59cd7..4efadc9e0 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act """ diff --git a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py index a7a952fb9..ec2966582 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py +++ b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from copy import deepcopy from math import ceil diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py index 0c3a8d19c..8133a36af 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy""" import shutil diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py index 008287506..cab2bdc52 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface""" import logging diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py index 1b12c0b7e..4feb1dcfd 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/utils.py +++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from concurrent.futures import ThreadPoolExecutor from pathlib import Path diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py index 686edf4ca..899ebdde7 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Process pickle files formatted like in: https://github.com/fyhMer/fowm""" import pickle diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 96b8fbbc9..5cdd5f7c0 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json from pathlib import Path diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 0252be2ee..edfca918e 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import subprocess import warnings diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index c5fd46711..83f94cfea 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib import gymnasium as gym diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 5370d3857..8fce03698 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import einops import numpy as np import torch diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index ea8db050e..109f69511 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # TODO(rcadene, alexander-soare): clean this file """Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py""" diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index a3980b14d..95f443da1 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass, field diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index a795d87b0..e85a37360 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Action Chunking Transformer Policy As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 28a514ab3..d0554942b 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass, field diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 3115160fb..c67040b6b 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" TODO(alexander-soare): diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index a819d18ff..4c124b617 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import inspect import logging diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index ab57c8ba2..d638c5416 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import torch from torch import Tensor, nn diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py index b00cff5ca..38738a909 100644 --- a/lerobot/common/policies/policy_protocol.py +++ b/lerobot/common/policies/policy_protocol.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """A protocol that all policies should follow. This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index 00d00913d..ddf52248a 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass, field diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 1fba43d08..70e78c980 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Implementation of Finetuning Offline World Models in the Real World. The comments in this code may sometimes refer to these references: diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index b23c13366..8f7b6eecd 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import torch from torch import nn diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py index 642e0ff17..cd5f82450 100644 --- a/lerobot/common/utils/import_utils.py +++ b/lerobot/common/utils/import_utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib import logging diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py index 5d727bd74..b85f17c7a 100644 --- a/lerobot/common/utils/io_utils.py +++ b/lerobot/common/utils/io_utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import warnings import imageio diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 8fe621f46..d62507b59 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import os.path as osp import random diff --git a/lerobot/scripts/display_sys_info.py b/lerobot/scripts/display_sys_info.py index e4ea4260c..4d8b48504 100644 --- a/lerobot/scripts/display_sys_info.py +++ b/lerobot/scripts/display_sys_info.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import platform import huggingface_hub diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index e4a9bfefe..9c95633a1 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Evaluate a policy on an environment by running rollouts and computing metrics. Usage examples: diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index dfac410b3..16d890a79 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub, or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 7319e03fc..ab07695b9 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import time from copy import deepcopy diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index d4fafe673..58da6a47e 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. Note: The last frame of the episode doesnt always correspond to a final state. diff --git a/tests/conftest.py b/tests/conftest.py index 856ca4555..62f831aa3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .utils import DEVICE diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py index 17cf2b38f..554efe758 100644 --- a/tests/scripts/save_dataset_to_safetensors.py +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index 29e9a34f9..e79a94ff9 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import shutil from pathlib import Path diff --git a/tests/test_available.py b/tests/test_available.py index ead9296a7..db5bd520a 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib import gymnasium as gym diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 1d93d48f5..afea16a5e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import logging from copy import deepcopy diff --git a/tests/test_envs.py b/tests/test_envs.py index f172a6458..aec9999da 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib import gymnasium as gym diff --git a/tests/test_examples.py b/tests/test_examples.py index 543eb022f..de95a9915 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # TODO(aliberts): Mute logging for these tests import subprocess import sys diff --git a/tests/test_policies.py b/tests/test_policies.py index f0fa7c563..c84578541 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import inspect from pathlib import Path diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py index 0124afd3f..999540402 100644 --- a/tests/test_visualize_dataset.py +++ b/tests/test_visualize_dataset.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from lerobot.scripts.visualize_dataset import visualize_dataset diff --git a/tests/utils.py b/tests/utils.py index 74e3ba8f5..ba49ee706 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import platform from functools import wraps From 68c1b13406068b9d88afbfcb2366f927141514f3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 16 May 2024 13:51:53 +0100 Subject: [PATCH 2/4] Make policies compatible with other/multiple image keys (#149) --- .../common/policies/act/configuration_act.py | 7 --- lerobot/common/policies/act/modeling_act.py | 46 +++++++------------ .../diffusion/configuration_diffusion.py | 17 +++++-- .../policies/diffusion/modeling_diffusion.py | 34 +++++++++----- .../policies/tdmpc/configuration_tdmpc.py | 12 +++-- .../common/policies/tdmpc/modeling_tdmpc.py | 20 ++++---- lerobot/common/policies/utils.py | 4 ++ lerobot/scripts/train.py | 3 +- tests/test_policies.py | 33 +++++++++++++ 9 files changed, 107 insertions(+), 69 deletions(-) diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 95f443da1..be444b06b 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -145,10 +145,3 @@ class ACTConfig: raise ValueError( f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" ) - # Check that there is only one image. - # TODO(alexander-soare): generalize this to multiple images. - if ( - sum(k.startswith("observation.images.") for k in self.input_shapes) != 1 - or "observation.images.top" not in self.input_shapes - ): - raise ValueError('For now, only "observation.images.top" is accepted for an image input.') diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index e85a37360..4a8df1cee 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -62,6 +62,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): if config is None: config = ACTConfig() self.config = config + self.normalize_inputs = Normalize( config.input_shapes, config.input_normalization_modes, dataset_stats ) @@ -71,8 +72,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) + self.model = ACT(config) + self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + + self.reset() + def reset(self): """This should be called whenever the environment is reset.""" if self.config.n_action_steps is not None: @@ -86,13 +92,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ - assert "observation.images.top" in batch - assert "observation.state" in batch - self.eval() batch = self.normalize_inputs(batch) - self._stack_images(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self._action_queue) == 0: # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue @@ -108,8 +111,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) - self._stack_images(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( @@ -132,21 +135,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): return loss_dict - def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """Stacks all the images in a batch and puts them in a new key: "observation.images". - - This function expects `batch` to have (at least): - { - "observation.state": (B, state_dim) batch of robot states. - "observation.images.{name}": (B, C, H, W) tensor of images. - } - """ - # Stack images in the order dictated by input_shapes. - batch["observation.images"] = torch.stack( - [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")], - dim=-4, - ) - class ACT(nn.Module): """Action Chunking Transformer: The underlying neural network for ACTPolicy. @@ -176,10 +164,10 @@ class ACT(nn.Module): │ encoder │ │ │ │Transf.│ │ │ │ │ │ │encoder│ │ └───▲─────┘ │ │ │ │ │ - │ │ │ └───▲───┘ │ - │ │ │ │ │ - inputs └─────┼─────┘ │ - │ │ + │ │ │ └▲──▲─▲─┘ │ + │ │ │ │ │ │ │ + inputs └─────┼──┘ │ image emb. │ + │ state emb. │ └───────────────────────┘ """ @@ -321,18 +309,18 @@ class ACT(nn.Module): all_cam_features.append(cam_features) all_cam_pos_embeds.append(cam_pos_embed) # Concatenate camera observation feature maps and positional embeddings along the width dimension. - encoder_in = torch.cat(all_cam_features, axis=3) - cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3) + encoder_in = torch.cat(all_cam_features, axis=-1) + cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1) # Get positional embeddings for robot state and latent. - robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) - latent_embed = self.encoder_latent_input_proj(latent_sample) + robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C) + latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C) # Stack encoder input and positional embeddings moving to (S, B, C). encoder_in = torch.cat( [ torch.stack([latent_embed, robot_state_embed], axis=0), - encoder_in.flatten(2).permute(2, 0, 1), + einops.rearrange(encoder_in, "b c h w -> (h w) b c"), ] ) pos_embed = torch.cat( diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index d0554942b..632f6cd69 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -148,14 +148,21 @@ class DiffusionConfig: raise ValueError( f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) + # There should only be one image key. + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} + if len(image_keys) != 1: + raise ValueError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) + image_key = next(iter(image_keys)) if ( - self.crop_shape[0] > self.input_shapes["observation.image"][1] - or self.crop_shape[1] > self.input_shapes["observation.image"][2] + self.crop_shape[0] > self.input_shapes[image_key][1] + or self.crop_shape[1] > self.input_shapes[image_key][2] ): raise ValueError( - f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} ' - f'for `crop_shape` and {self.input_shapes["observation.image"]} for ' - '`input_shapes["observation.image"]`.' + f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} " + f"for `crop_shape` and {self.input_shapes[image_key]} for " + "`input_shapes[{image_key}]`." ) supported_prediction_types = ["epsilon", "sample"] if self.prediction_type not in supported_prediction_types: diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index c67040b6b..1659b68eb 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -19,6 +19,7 @@ TODO(alexander-soare): - Remove reliance on Robomimic for SpatialSoftmax. - Remove reliance on diffusers for DDPMScheduler and LR scheduler. + - Make compatible with multiple image keys. """ import math @@ -83,10 +84,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): self.diffusion = DiffusionModel(config) + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + # Note: This check is covered in the post-init of the config but have a sanity check just in case. + if len(image_keys) != 1: + raise NotImplementedError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) + self.input_image_key = image_keys[0] + + self.reset() + def reset(self): - """ - Clear observation and action queues. Should be called on `env.reset()` - """ + """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { "observation.image": deque(maxlen=self.config.n_obs_steps), "observation.state": deque(maxlen=self.config.n_obs_steps), @@ -115,16 +124,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): "horizon" may not the best name to describe what the variable actually means, because this period is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ - assert "observation.image" in batch - assert "observation.state" in batch - batch = self.normalize_inputs(batch) + batch["observation.image"] = batch[self.input_image_key] self._queues = populate_queues(self._queues, batch) if len(self._queues["action"]) == 0: # stack n latest observations from the queue - batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} actions = self.diffusion.generate_actions(batch) # TODO(rcadene): make above methods return output dictionary? @@ -138,6 +145,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) + batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} @@ -215,13 +223,12 @@ class DiffusionModel(nn.Module): def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: """ - This function expects `batch` to have (at least): + This function expects `batch` to have: { "observation.state": (B, n_obs_steps, state_dim) "observation.image": (B, n_obs_steps, C, H, W) } """ - assert set(batch).issuperset({"observation.state", "observation.image"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] assert n_obs_steps == self.config.n_obs_steps @@ -345,9 +352,12 @@ class DiffusionRgbEncoder(nn.Module): # Set up pooling and final layers. # Use a dry run to get the feature map shape. - # The dummy input should take the number of image channels from `config.input_shapes` and it should use the - # height and width from `config.crop_shape`. - dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape)) + # The dummy input should take the number of image channels from `config.input_shapes` and it should + # use the height and width from `config.crop_shape`. + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + assert len(image_keys) == 1 + image_key = image_keys[0] + dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape)) with torch.inference_mode(): dummy_feature_map = self.backbone(dummy_input) feature_map_shape = tuple(dummy_feature_map.shape[1:]) diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index ddf52248a..cf76fb08a 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -147,12 +147,18 @@ class TDMPCConfig: def __post_init__(self): """Input validation (not exhaustive).""" - if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]: + # There should only be one image key. + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} + if len(image_keys) != 1: + raise ValueError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) + image_key = next(iter(image_keys)) + if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]: # TODO(alexander-soare): This limitation is solely because of code in the random shift # augmentation. It should be able to be removed. raise ValueError( - "Only square images are handled now. Got image shape " - f"{self.input_shapes['observation.image']}." + f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}." ) if self.n_gaussian_samples <= 0: raise ValueError( diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 70e78c980..7c873bf23 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -112,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): config.output_shapes, config.output_normalization_modes, dataset_stats ) - def save(self, fp): - """Save state dict of TOLD model to filepath.""" - torch.save(self.state_dict(), fp) + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + # Note: This check is covered in the post-init of the config but have a sanity check just in case. + assert len(image_keys) == 1 + self.input_image_key = image_keys[0] - def load(self, fp): - """Load a saved state dict from filepath into current agent.""" - self.load_state_dict(torch.load(fp)) + self.reset() def reset(self): """ @@ -137,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): @torch.no_grad() def select_action(self, batch: dict[str, Tensor]): """Select a single action given environment observations.""" - assert "observation.image" in batch - assert "observation.state" in batch - batch = self.normalize_inputs(batch) + batch["observation.image"] = batch[self.input_image_key] self._queues = populate_queues(self._queues, batch) @@ -319,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): device = get_device_from_parameters(self) batch = self.normalize_inputs(batch) + batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) info = {} - # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation. - batch_size = batch["index"].shape[0] - # (b, t) -> (t, b) for key in batch: if batch[key].ndim > 1: @@ -353,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): # Run latent rollout using the latent dynamics model and policy model. # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action # gives us a next `z`. + batch_size = batch["index"].shape[0] z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) z_preds[0] = self.model.encode(current_observation) reward_preds = torch.empty_like(reward, device=device) diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index 8f7b6eecd..5a62daa2a 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -19,6 +19,10 @@ from torch import nn def populate_queues(queues, batch): for key in batch: + # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the + # queues have the keys they want). + if key not in queues: + continue if len(queues[key]) != queues[key].maxlen: # initialize by copying the first observation several times until the queue is full while len(queues[key]) != queues[key].maxlen: diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index ab07695b9..7ca7a0b3c 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -23,6 +23,7 @@ import hydra import torch from datasets import concatenate_datasets from datasets.utils import disable_progress_bars, enable_progress_bars +from omegaconf import DictConfig from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle @@ -307,7 +308,7 @@ def add_episodes_inplace( sampler.num_samples = len(concat_dataset) -def train(cfg: dict, out_dir=None, job_name=None): +def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): if out_dir is None: raise NotImplementedError() if job_name is None: diff --git a/tests/test_policies.py b/tests/test_policies.py index c84578541..75633fe66 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -64,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str): "act", ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"], ), + # Note: these parameters also need custom logic in the test function for overriding the Hydra config. + ( + "aloha", + "diffusion", + ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"], + ), + # Note: these parameters also need custom logic in the test function for overriding the Hydra config. + ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]), ], ) @require_env @@ -87,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides): + extra_overrides, ) + # Additional config override logic. + if env_name == "aloha" and policy_name == "diffusion": + for keys in [ + ("training", "delta_timestamps"), + ("policy", "input_shapes"), + ("policy", "input_normalization_modes"), + ]: + dct = dict(cfg[keys[0]][keys[1]]) + dct["observation.images.top"] = dct["observation.image"] + del dct["observation.image"] + cfg[keys[0]][keys[1]] = dct + cfg.override_dataset_stats = None + + # Additional config override logic. + if env_name == "pusht" and policy_name == "act": + for keys in [ + ("policy", "input_shapes"), + ("policy", "input_normalization_modes"), + ]: + dct = dict(cfg[keys[0]][keys[1]]) + dct["observation.image"] = dct["observation.images.top"] + del dct["observation.images.top"] + cfg[keys[0]][keys[1]] = dct + cfg.override_dataset_stats = None + # Check that we can make the policy object. dataset = make_dataset(cfg) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) From c9069df9f1e09a98f193eacc7241adead2d10553 Mon Sep 17 00:00:00 2001 From: Akshay Kashyap Date: Thu, 16 May 2024 10:34:10 -0400 Subject: [PATCH 3/4] Port SpatialSoftmax and remove Robomimic dependency (#182) Co-authored-by: Alexander Soare --- .../policies/diffusion/modeling_diffusion.py | 74 +++++++- poetry.lock | 179 +----------------- pyproject.toml | 1 - .../pusht_diffusion/actions.safetensors | Bin 4600 -> 4600 bytes .../pusht_diffusion/grad_stats.safetensors | Bin 47424 -> 47424 bytes .../pusht_diffusion/output_dict.safetensors | Bin 68 -> 68 bytes .../pusht_diffusion/param_stats.safetensors | Bin 49120 -> 49120 bytes 7 files changed, 75 insertions(+), 179 deletions(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 1659b68eb..2ae03f221 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -17,7 +17,6 @@ """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" TODO(alexander-soare): - - Remove reliance on Robomimic for SpatialSoftmax. - Remove reliance on diffusers for DDPMScheduler and LR scheduler. - Make compatible with multiple image keys. """ @@ -27,13 +26,13 @@ from collections import deque from typing import Callable import einops +import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torchvision from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from huggingface_hub import PyTorchModelHubMixin -from robomimic.models.base_nets import SpatialSoftmax from torch import Tensor, nn from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig @@ -312,6 +311,77 @@ class DiffusionModel(nn.Module): return loss.mean() +class SpatialSoftmax(nn.Module): + """ + Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. + (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + + At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" + of activations of each channel, i.e., keypoints in the image space for the policy to focus on. + + Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2): + ----------------------------------------------------- + | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) | + | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) | + | ... | ... | ... | ... | + | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) | + ----------------------------------------------------- + This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot + product with the coordinates (120x2) to get expected points of maximal activation (512x2). + + The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally + provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable + linear mapping (in_channels, H, W) -> (num_kp, H, W). + """ + + def __init__(self, input_shape, num_kp=None): + """ + Args: + input_shape (list): (C, H, W) input feature map shape. + num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input. + """ + super().__init__() + + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._out_c = num_kp + else: + self.nets = None + self._out_c = self._in_c + + # we could use torch.linspace directly but that seems to behave slightly differently than numpy + # and causes a small degradation in pc_success of pre-trained models. + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() + pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() + # register as buffer so it's moved to the correct device. + self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1)) + + def forward(self, features: Tensor) -> Tensor: + """ + Args: + features: (B, C, H, W) input feature maps. + Returns: + (B, K, 2) image-space coordinates of keypoints. + """ + if self.nets is not None: + features = self.nets(features) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + features = features.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(features, dim=-1) + # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions + expected_xy = attention @ self.pos_grid + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._out_c, 2) + + return feature_keypoints + + class DiffusionRgbEncoder(nn.Module): """Encoder an RGB image into a 1D feature vector. diff --git a/poetry.lock b/poetry.lock index 388e03f40..e0b27f159 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4,7 +4,7 @@ name = "absl-py" version = "2.1.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, @@ -767,16 +767,6 @@ files = [ [package.dependencies] six = ">=1.4.0" -[[package]] -name = "egl-probe" -version = "1.0.2" -description = "" -optional = false -python-versions = "*" -files = [ - {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"}, -] - [[package]] name = "einops" version = "0.8.0" @@ -1037,64 +1027,6 @@ files = [ [package.extras] preview = ["glfw-preview"] -[[package]] -name = "grpcio" -version = "1.63.0" -description = "HTTP/2-based RPC framework" -optional = false -python-versions = ">=3.8" -files = [ - {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, - {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, - {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, - {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, - {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, - {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, - {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, - {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, - {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, - {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, - {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, - {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, - {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, - {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, - {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, - {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, - {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, - {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, - {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, - {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, - {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.63.0)"] - [[package]] name = "gym-aloha" version = "0.1.1" @@ -1668,7 +1600,6 @@ files = [ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:9e2addd2d1866fe112bc6f80117bcc6bc25191c5ed1bfbcf9f1386a884252ae8"}, {file = "lxml-5.2.1-cp37-cp37m-win32.whl", hash = "sha256:f51969bac61441fd31f028d7b3b45962f3ecebf691a510495e5d2cd8c8092dbd"}, {file = "lxml-5.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b0b58fbfa1bf7367dde8a557994e3b1637294be6cf2169810375caf8571a085c"}, - {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3e183c6e3298a2ed5af9d7a356ea823bccaab4ec2349dc9ed83999fd289d14d5"}, {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:804f74efe22b6a227306dd890eecc4f8c59ff25ca35f1f14e7482bbce96ef10b"}, {file = "lxml-5.2.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08802f0c56ed150cc6885ae0788a321b73505d2263ee56dad84d200cab11c07a"}, {file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8c09ed18ecb4ebf23e02b8e7a22a05d6411911e6fabef3a36e4f371f4f2585"}, @@ -1740,21 +1671,6 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=3.0.10)"] -[[package]] -name = "markdown" -version = "3.6" -description = "Python implementation of John Gruber's Markdown." -optional = false -python-versions = ">=3.8" -files = [ - {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, - {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, -] - -[package.extras] -docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] -testing = ["coverage", "pyyaml"] - [[package]] name = "markupsafe" version = "2.1.5" @@ -3056,6 +2972,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3224,30 +3141,6 @@ typing-extensions = ">=4.5" [package.extras] tests = ["pytest (==7.1.2)"] -[[package]] -name = "robomimic" -version = "0.2.0" -description = "robomimic: A Modular Framework for Robot Learning from Demonstration" -optional = false -python-versions = ">=3" -files = [ - {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"}, -] - -[package.dependencies] -egl_probe = ">=1.0.1" -h5py = "*" -imageio = "*" -imageio-ffmpeg = "*" -numpy = ">=1.13.3" -psutil = "*" -tensorboard = "*" -tensorboardX = "*" -termcolor = "*" -torch = "*" -torchvision = "*" -tqdm = "*" - [[package]] name = "safetensors" version = "0.4.3" @@ -3738,55 +3631,6 @@ files = [ {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, ] -[[package]] -name = "tensorboard" -version = "2.16.2" -description = "TensorBoard lets you watch Tensors Flow" -optional = false -python-versions = ">=3.9" -files = [ - {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, -] - -[package.dependencies] -absl-py = ">=0.4" -grpcio = ">=1.48.2" -markdown = ">=2.6.8" -numpy = ">=1.12.0" -protobuf = ">=3.19.6,<4.24.0 || >4.24.0" -setuptools = ">=41.0.0" -six = ">1.9" -tensorboard-data-server = ">=0.7.0,<0.8.0" -werkzeug = ">=1.0.1" - -[[package]] -name = "tensorboard-data-server" -version = "0.7.2" -description = "Fast data loading for TensorBoard" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, - {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, - {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, -] - -[[package]] -name = "tensorboardx" -version = "2.6.2.2" -description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" -optional = false -python-versions = "*" -files = [ - {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, - {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, -] - -[package.dependencies] -numpy = "*" -packaging = "*" -protobuf = ">=3.20" - [[package]] name = "termcolor" version = "2.4.0" @@ -4064,23 +3908,6 @@ perf = ["orjson"] reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] -[[package]] -name = "werkzeug" -version = "3.0.3" -description = "The comprehensive WSGI web application library." -optional = false -python-versions = ">=3.8" -files = [ - {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, - {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, -] - -[package.dependencies] -MarkupSafe = ">=2.1.1" - -[package.extras] -watchdog = ["watchdog (>=2.3)"] - [[package]] name = "xxhash" version = "3.4.1" @@ -4348,4 +4175,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2f0d2cbf4a2dec546e25b29b9b108ff1f97b4c278b718360b3f7f6a2bf9dcef8" +content-hash = "e3e3c306a5519e4f716a1ac086ad9b734efedcac077a0ec71e5bc16349a1e559" diff --git a/pyproject.toml b/pyproject.toml index 24d9452d8..5b80d06f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ diffusers = "^0.27.2" torchvision = ">=0.18.0" h5py = ">=3.10.0" huggingface-hub = ">=0.21.4" -robomimic = "0.2.0" gymnasium = ">=0.29.1" cmake = ">=3.29.0.1" gym-pusht = { version = ">=0.1.3", optional = true} diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors index 730f5b2bc2a801d15b4ade3593c90f95650f5472..8f03990351292611f702c163ee387b3d7248b5f0 100644 GIT binary patch literal 4600 zcmb7HdsK~C8;?kl8b&kjn)4BI(3Q}bv-eM`!5}mx8pn{7GKiQiLYGl#GK3n1XneVq z3`We;lu`u=p*I%~iCd7j^M+rPcv{kHYf;6Fc+iD<5$ zXoJt5?>36&iWZuiiNvBE?owgqGRs12wpJ|KAoY=cv(?RQ$3~wWOll&wG&NfrAQqW^ zD3amCrY2?<0@&*&%l z8K~x9-1n?Uh2w1uHh#|IggVv2^@}vFQ6BHw2> zlA`}Au-~f>ZuUKBEiA{Jqu*oq&S+eEw+54^eMeu1#Gx=vhhZfoI4_%zHpj~O_+Wnr zt3^TrEJ{eXbq!hDzaEKpzw!FkYM6+C4=WO*FF}>h(LxvKWujOzZ7f5IAZbzx|1YghSu|LsQZbsC%>1R}= zrQzYi>Ab%}KRibapq}QBU~|xp``xxy4v!hWh_h@Z8OGx4-_sdga3l$ld0D)`9sO`b`Zf)>Lvn7^Z|HyFG#DSNEd^O_=cfeOw{G zLhVZWqfrZndRAg{QUMw-$8dgH{n?1O`Z0L7Ri9pWmZN;T8JCZ}bev@Bt9iLjdKV?S zr8pbZfV}mQWUM#<)&))6F6%1e0;_vK>r0%mm1Cn@HaI3uz$p($Vw^sS zMlUv@|DF5*R|>bn__JIXJ-^TKlI68L-)r@W)Lz!W>sT?n59Om*l2O0Nu}`{-j9>JO z%Y}GOS1Zt2^@z_O^tREWV|p-Ic8_YA^>Bam)KNIuHomtPX`e@AnR=e)|2NF1OvD>m z4<0nTKvAaPcGd@5@K|vI$pe&V+Y*C~rAnNxaw8q7y6_#-1>IlTF-g+_H5XgCy@T^D z9Clod)l-v^@NgeL7kyINu>E-$*Wb~X^Yzu4+J}H98B#Nqu$ONmjo~U3&@OBV>w{v~ zO>`~j!HO~JeXY1R5zf}lSeCn<6i_8@1Sqg0wi~x5%qh`!I(+39~QyMqd8M;!T$1z(w5p_~^3$l=a2$6AMUdbde)N+}H7~p48TXn0m?tO6g0MGa<$vUGNv|fq8(imp zc4Y`XRKHKkU0TpD?J++>a=Y_~Gis1T;(ZZP^BUZLo_!R@F}OO8pt_g_?8&BO zClx5!6iltl<}q6dIpQ>o=PvSmiin_d>V3`Rh36ZH<0iI${>fWB-ML>EtFyRtyAS*4U4>g^7Al*g;9>d>Q${so@A?|_ z{85LzOThhs&p4jPW4+wcRE__f-i8Dlf9U9JqZfW#Nh=)VQM`E%J(|-_HEr(F+`ac< zu<9I-(>pVrj&nYQB_FGx6xCsLFC8}-)ne0Nxh%eTI1T&^#y5zaITv0C;FYlmf`V->rPu=#JDLTwW6SM_!PdzO(Q z@;zAD*-1{0Jji4Gt2fJDz~?0^-Xd_CN_}p!Jh!^qk%NX22u+(r>pt~-@9)X4>T}CW z*Ov^F4P;!BkUi#=NG@RST!s_=`#U?3G4vtnY1tY4{tEA6Z*flC^jbvvXDe|q zE|}yf)cyX~zci#7l7tRA{V>mkzU|=ijM$oN41hlczfHdNfkh z=eK}KwsxX=&pw8>Jd}*LO~kR>ebmtH6?s+8IGYuR5yx&4uN*1#W*&xkVHt_N5RTmP zX~<1YgPC6{DhyNbs$P8ukskJh-IqU5$GAu;3mQpZTBh^9NE?NBxPL~7J{RJd20fuc^JV19>(|_VMv*uDGM4!i)}HA{*$8(?85x}tfVgz`n$Ur zKjt@}KK^$eM{8XUQuisa$RmvRt;S1*hoQonsn5}fZ!V|yhgMq^ z^oI$xONiE(Hu@&xZ?0QpyOX#y^kHwHHGM$ta=n0QFUk6+R}ymAQi;;25cSzspg>CM zBjkN5rm9g_F|w$P4oQzhaz+?hb}I2#Vr^WL5ga@6R8>2sG_cE5ppcMhM_p9+xowcX3$0}q*7huY6 zR!`SWk{DBf&@5i8)T->q>!Ce0%yk7fL5_H+MkyJl9u`;v8?oLB^A5z~zIG32O zrS2*u_D@#(X7vejhuX&|wk_ncfwxrH%Xofo-aGu!nga`GwwIYr$*?-Q?Bh;!S?t5& zdYiG=A)C)aiFphf8wYv&ez>ttxJ}Np_s` zM{leOFS|M{N&5{%JLO>>+iG@!MA^)hCkbKK8o_H0+8UeD*!0da4MX$DwcQG42KBK7Y2_Oan4Xh;d*wk9WqHeK)?g zq37(ix!v`&W z3w(td)aSP-u$=ollxRc$2;6`tn`C_MU3Y^zH6+oG8bfJKzhNi{Wcz}}6S*>u$RASM z=vN4r5}vmp*K-IZJ_f~PqUsd4o0(q&(ZEKs`9KAtBXeqIY7hiB%)%Va9P>=nK_WaNLY{X200 literal 4600 zcmb7Hd0dUzAC5^=A;d7YwArt;sD><^^AI&ng%Beu*|JPZB4o59LTO4dEeaK3k|^BL zaLZOn5-M8IK5bg=t>62eYntD``h5DF_q^wMzTfA&oO9mW)K7-}?N`@UpR-@xarZvw zjp}pMP4slt1?sL_?D^TR2D1dZwgPoW``z~If7`Ulb>nVV?ya^!Uq{#Wpg>*cTbUeB zprfrjizn-TE14$>boBKMc(&d*vpJ$bS9{h^JbTu+vU#FFSKmOJXX}46no4kopUeE- zi_3Yp(U>Hm9kq#c&d_r9e($jr;MdHrH_$tFR?r+xe;$8|5`bOtobrueaI1jOmh*5fOf_W+TRae{+OQ z(V%HnS;(``r02_ekT6pdYqsgZY+(l!BHO6^UrkV5n?yGi%*UriA;j%*4`f_KRIv9m zIy$@IYrF&>Y9w@vqcT3_T;sHe$lRw(zUD`dW;IWwJ_qed)8KaKnEoF&`?CvEw65eLu1|c-_1{m8ks)-XrF6FV*|`Mhx9)5=Q<>UyXrJRG@h~m-M^; z3tE;R#{Cs3RK`mT+3mK}aitw~ukON)4l&%G3TbIv0<~XW2VIFRRl5HN(J~ z+U01C9HR#;FTxq6#3EaxFJF&?-(a(=3o%tor5~nz!i)X_C{5c=Hz<^l{<&eyPL($x z(^5-{UM+>z>F@EI@D3TcDFWE_oUQY5c%{v!W^TXH3(w`Ss(KNzyxGO_9}zi>s`?b+ z!RdD-e4v2g6Bo5%&lWT8*_h~#kih)lioW692N|6|Gu)&@&q%mhJC>)2 zG4A~lcza0ulJC#@n+^DDk2@-dU*|MR@aB;xr&&ZYbY^3~0X3FS9;fW155HX>kX(CS z+E?5=m$ue;g43Kb3e%=qSi*AO>LO~@J+<6WT*Rka>K`=`nvMSl=lCp9xYyr1VK^~lybMSh)` zgOQ~bIFT(+HRsO522l^1^)IqGcwAm#DLEXylFgHwyXj2lIw+i_NVCj?Rjg8MSCZ(c!x)Zz)$Wu@`M>IPsQ3^#P>`&YVPJj7$ZUd`<+PrY)DOpwlI8a z(HW{&9}UH^UaSt6SHD5#*d36IaUjRquOM4vA?a1qK+?x}rYHB)C~WiSf)AIQ#&wmn z#@&e7xsoCwo&$U_H;n5$t{&{4tLvDbtGFDA$l!-!Hd>x3!wKcjq#~;fmlk+q;PlYG zcq|O8af>#xdG$O6Tr(WmXD22xjqy8L?aAv9?%d{lDBhK$8&3=A%jf5C-z1Z*Z`4hJ zLtrPcY@)>52fcuUe-C1b4`gx4U+av@2U%Io%O}n z^V8;t52ghEiNU|jrBTZ_qk4usZSxvQf6Sgl^SrJy{NdiOb>?H|cuLQ>yX-6fWxfn; z`@Rzk-EC>)TuujvDb49Y#*APZkwM6`xNf#@U(8TM&$3tu_di0e;!Ma_TR_D%TWW{n z2t(aqGhLN_s4rr*==W?7%e~DD2kLzG1W6w_0Jp}B)R*{Bjzl zlZv3XZZbq2!=ZCjs%74~Vsd?f7!5BTpuf>=gon(h2A+fIrInA+u|pnBmmbmeZT?WU zI?e2uPPqY%mh-G%tlhZ0aoU2lJK!+Ym_74xxBL51{SDm=|8E|oJ@J40mcvI4}8X+n5aVxr1e6 zHZtE%pAI4F)gzDzm+1I?m6-6Vlq}Zi#m@6~)cR-@3gTt^>Y_8LjQP76(g{E18fGuK zCJk$D6~gAwJ`$i8s+C~$eGgPN2jS(`DTqriWqA+2@Dnx2 zu0(?GF0lIRrOk>hz2WHpQ1%3ItxIL8NQg$U%-EpsQDkgz03jGW% zu4WHf@51HE0rK6*B8Z=!hwjlV{OI0<2}!v~*nN!ISnP5Ojk6OGxTlLcrTxI{$Bn2& zLg5Bfe#yn!fWyS7dwpL#nlZnVF)^>%JpJi%lx9|8dsHLqL&3+P)Ymr(OPpHBhN=^E zO44@hw4lV&^*P=p8NyPg*wu=fAl`YR)pY?y9mK<>M^KC}(@LZ0kn5Z2{9YXu(>l@I(mHAuF)$O9Qh% zZAdDMr?U1k?j$dye#V8=DmE5_CZ?g?Ae{03-7h)=OsL1lU>aoDB8}aR|CTT&a>ioZ z_YXw8rZZ0ee3fK9?a_EF5ujxFQeJn3-e{0~l!B(oGrV8_f zO9))8%yQUY5Xf|y+jZi0dJauh6Ou*VPq174FQ(HWfUkjw*d;gPqK7Q?%sx%WE-Ate z(O}#vucAkmHbY{zkWRmR0=2Z4v~Asw{X4n)m8+lG;xu~KYYDj*a1f^cJusMk3yvv1 zWKF9Hi}U8GU2I%AsRt1~T`)N*fy@3Q5Zk^+a`kB3*rZFR+n-_ntKT8~I9Y!>?~?>sV}~Nyn7eN||Ci^# z?bFe(xC6%nEa~Q%S1~W;G@hhJA>u&`Mo&(sR(rYcR1SaOMF&mF{GR^)rUD_8^Kq`T8PDd7M_FqUGKMT64?d@3 znwl5~HZ`MZOEuI^%F@;MF5*%98`>|wA3SD-Al5$&d0K^Z<}i*wPzx<~{BIW#@$}{H zGHx6nRzb7I{X^otrTxhJ&)YH;oxz0+m*1O5MG@<%-^gp^ZRK$~J|PrxHqS`;mxEMc z^)9+Q^8~r6*#(uA8LV&l`bNht!j#TWp&puqVYB@4Bi4`l>h`LR3d&wKMo^3*Z4Rci#k}&W5R_rVX}0; zIUasN*Qa%|Ug=zwPcsjcL1XrxsH^+}i>4H6o@7sSN}OOmfvY3uFK%fy;%E3mrr;GB zdT}dejw4}jpsL>Y3WcCX8k8*qs2CIyGa-toArVo!!8m{o z3SwkmWD!~H08MvSvx$*NTu z;g*FL75b|7Z_aPTR`vE}5q!J~s}^?G0?|3>^8LAT@YZ12^|4hDWU7G{E1Dqcr%Z*P zFcB~MD zXIL*NdtoFe@-{(Vsy#cRNdjSg-bFui3xqrjqYXrbmT>pe1GK;CX87!Z8MUOh4fMi7 z6#;x*@IN7S?-+CYKge;XRl0Pz+705?8_RE9wBuHNS=#l=5zanuCLN6LqsG9M%BGQZ z5aDJccipcIJ#SOxK1DikZ>JTLp5DR;t9b0p@id$j-OGx!K2c$pU{+<` zei8eAr)K*i?Xl52C9MxNdzipFPz?-}k zY702mFif7F`5AddP}JKK|3w7h)v$H$JfpBo_ShJF5t7&3r~H$Ja3r@_sj3cuw6DA9 zJ1vzUHA)gu%Tfzq`s$??Pe)58cCuaRs`F1S;C-j5K|d(y6$&)ZI0EtNQM2id+bAQv zjkJ;8=K^vXj(RJZ#BY1$D|wf|=Rap?*QqdgeYlN2ZF&T*_!3}qeFdJ(iAEQaSK}iN z+W6LVBx6`&L6_(pHu;d94rig~zlSMHgXzo!m9fn{w8O_$HbvdBlSyNZTb6y~p zT^M=-kD>;Yb=$5$;{91lI+6=Ikts@Fy?1cC{wR69z6>-Hj9WIX`;hT}%2y)T!JW5u z+U_oeH&dqaLCG!)ZR_#&Q$DpbyA$2?^8;Ke^ z^PBay7@|u_?X1DRamd(wL(aERa|VQcg|858cbODi^P8Zj>dz?#d(YF$-ZU#7^CX<^ zZeMnzMTiquU1I)Z<=utTpe;iKkrW^UnIPu#cj7`%j2i$&m#lqk@P|^ zOO5i7$f`L$ug=Z5nwG`Q$+(AszPCrqiXj!~H>cKIhk_+z>vR zkNAPK+WZ?WDb}N2KAWZIs!Qqi`UAMM?50N^d7NuF=CIaZj81;JO&;l=glu={lTR$( zL;G)`xdz~{x-_~DuhZu)Qxtru9RA|eH{cHUPd84 zfRE-GR;%k9JaKVCPWDMS197c6lZCPj+}(E#?L+k?U2<26DGFXXNCwMn{?Z*a_mB*+ zzfIs7!h_5G`4lc3Kt@HY>H6YfL_Qm(=a*~YXG<(~%`z9XzP{2bs26xj)^xGa5gwQz zCZwV)d>9L}_fqBEHaK%u&5r3mg_VX?Z1ME?U$?knK8*DY-4IIdIaY;L>8Ywl-4vYd zLa5qG#aL23t%6D=4v&1ye!@$|HMtjB`I~e-A04{N-m%r^0=%3SlpAo}?dPN!^#b%p zd*}@APpFtHrR(?BaQ>Z+4PQ!&J&E*8_xBFv;TY3uRm=L5`17NsssP7>xOZnQWphB6 N2@G4MTx_#u{sckZ-xdG> delta 1719 zcmX|=FZ%Ts}s(r%@ctUTNxBoQ++_csHpQSn$ri!Pw3z@mVGB2_pdT1l!1 zU6hxAgg|)&L6k=zl4Js+0TB@jg-RbrWZ8<}bPE3k%456~IhJuLO}3ZrYK@q>Y-c2To)`eb)(AExJpp-iABym}58q@p zq6POSU>Ypd7V#Tkg~e`7*ScDWzIU4#3)2HONwpk16kIq4Qr@nBpa-`! z6L$OIqwb&ZCySzBWNIHnM~K;B2Hpn-$qA^mM`SK15l-)QC%eSi5V!t}W@r3ia0tJS zPp+rHbjOmYERp`r#V;nAVE1^K^gKcazrTeoN>Fw}1lW9+Q9CcRTuqJEAw0-u|PizIPi$CDAS&q<-S&>IS5i>T1lB2<{~L`4Pv1mWN9 zp^u$^PHY}ZWcxqxJ=Hx8(oTu9-}I>jxUn^ckvw#UxgYYF<(C{;K+5NEEd1|oYi=Kl znSz2OXKnpt10BBfuf^EH)t{pBKCYmv-w#$ z#1sswGmU`|IA_D~?#!UD%LxqMbA}xd&BIIQ=?Tdcca^bl9Nzi8pV(w-fT44qWUuyH z_@c?1oZom3nI;N}3nDpM{oi|Jp{qGH7Hy5&uC-wcMplG0Hle_8y8Lc_XNOvjtMN^e z4)S~MA)&EEN3t*E`?*&YrMhTT#h*+=sdXFiNQVIQWhkck#ZiQ2FN+C#$VHOgI(v1} zNozEDB2gQnz6;$}mVR-wKeGiG7SI)h=KtgBH9$4W0nZA}EsIlZUeVIz zSwwO4;?6{nX^ae`u=h}VwgjBes&znt|n4BL<4Ny zb&gU557X1H#wB#@rY-bttjhoN@Mi7T9~Z-pTTFZTfl5~3*0}gcKRN$m!=-^;3Nm5Q zM0?HN&E&DX?uD#D_M9N2A!|7}_D@Rpp>z|^SojpW7-HQTac*N4;0n}SkLWE01Z-F`~f*5aI3D<3( zGkMXO)K`6rG^;QEpDmDjE1ZhZk4O6j?!0|yRm10M(GUMbGV3s5>4@K-0-s-LU;1!h znBT#zHgI-lCy_NijLwcZllage$~p9i6flp$9F&CpS9$Q8y{xl4L*RcpLc!?51r&y5 zDJ5p1$yAi+gy!am8#R=oVp5VP!P59xdu`be`-?vnG^A?%Q=W;{${7V!-g-gzGCqTf zm{?$P#xkhLH{a^u;4M^j(h}Xhb}99*g{IDybSwWg8GTH?(2f;Y@j|i%)bY2S2vemA z<~9?C>( Pt?Bc-gS6{CoasLR6e#%s diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors index 5f33535d5ad316f4cef55925a215428643ed2f52..a9f61b36ed2735850072768de10a153ef1aea7aa 100644 GIT binary patch delta 9 QcmZ>9nc%?kuA##o01%i1k^lez delta 9 QcmZ>9nc%>3p0ULq01lS|@Bjb+ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors index ade6a9e03118ac410f6f0ff9357b6779074437b6..a9f4608f6c1c544e8324803f154f57cb9f6acd96 100644 GIT binary patch delta 173 zcmV;e08;h_}JR*~kx<3gn zoKLrTs;WFyv)Q_T0RcFZS-c+sB9nQ%F9A%Gsk~niIFjc#(xDPP0|sZgbL}@g8j~5l z9|3BUIlWI30y9E5VzClEsMYT|bOt~?3VxHCy*?4GIj%Q@a}qt!tyH;GpgcShlj*%r b0gsb8z8?W&lUcq$0dSLlVQ3i2|LL%w>h_}JS&rtx<3gr zoKLrTs;WFIv)Q_T0RcCYS-c+sAd`8#F9A)Hsk~niIg;l$(xDPP0|sZgbL}@g9+Mfp z9|3NYIlWI3{4zo~VzClEsMYT|bOt~?|9z91y*?4RIj%Q@a}qt!tyH;GpgcSalj*%r b0g;nAz8?W%lUcq$0d$j@zFz?!lli{q5xGhY From 4d7d41cdee1e2406746ff38739fda2c58586e811 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 16 May 2024 15:43:25 +0100 Subject: [PATCH 4/4] Fix act action queue (#185) --- lerobot/common/policies/act/modeling_act.py | 6 +++--- .../aloha_act/actions.safetensors | Bin 5104 -> 5104 bytes .../aloha_act/grad_stats.safetensors | Bin 31688 -> 31688 bytes 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 4a8df1cee..3aab03cf8 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -98,13 +98,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self._action_queue) == 0: - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - actions = self.model(batch)[0][: self.config.n_action_steps] + actions = self.model(batch)[0][:, : self.config.n_action_steps] # TODO(rcadene): make _forward return output dictionary? actions = self.unnormalize_outputs({"action": actions})["action"] + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors index 7e7ad8e1df015d0ff52d689b317b8d77b3f380fa..3c9447d7fa0b68143216f21c4d9cf5c075253fe4 100644 GIT binary patch literal 5104 zcmb8y`CpCc8wc=ECJFH+l}a6H;gptB=P1v0x19DxyO6P*bCPsUY15ExgprtL#?n}# zO<7tjW9_<7N>MWzW*AG85mBVbl-BPt|AFVX=Xt%}ulw~}_kDd{_Z2--<-cb=&zk4B zo+nNECSJyKO=;fc^7K`{cf+c*Oba~S9rEyVFiL#_bC{;ci4RDXtiTL;m9 zC3E;VeSuRc&AN7 z{j+4uYd#-ORt><@(V>`-UkWFTJZRT-1sQLXv-8$wC>z@Vm3PKN-14d5{K*_&ymF!H zQYqVXY!-s*Ak@4LMxoVND3E(nvTQXeCdrw8=UANftOle%jD(nNqv1%)Y&49qBfBwC zjCVi>otO&C_6foHukzrxSRcB$HI?SQNo3MLXQB^R3(cYrT>4=(*j&fQ#@BNx^0thP z+%**)T}H65zCdhsI0^?+{U~F7JS8+GG0m0snAuzdnvOl(?_ZAxD=i_$@AjeGUK#UA z(nXaED$Ft`9Oqstf|_Vw>bkmy94{v@@lrA3jVcg7?cxM~O@v*KgsAQ5OZ6|J*!?se zymnrN4gL{|=Pnii&(D`ei{t6xp9w5Y>WeN{Zo=N5KXMN<^q@FlHjX;$PCh~jyI-e^ z5y=DKb0`cWZyg2ASMJmrxQep>Q81J9({Z8wWyo$G;=Fh9pf%GTvo?Cs%Lh{S+{h4h zXQ;5}b3;+II3KK%y(m|rpgl2h%+JyhPv+FX;UiyxMUxTO|89wo{$k{nEM>nn&cc9` zeK4&l1jFM`!wGdynrIh8o1^2Hc(w`l7vF)LRyEl0SQ|#Xw89Q!Px`e-!j=`7qyAU@ zaQok29KE*;*135R$D&AkY#f_+Wi0-9`zBIRo=v6}VzE|AaUv#0})EEPwIns}LQg(a2A*LSeg{|TR_`azSB7@!N z>u2%gTo%hRt9f|EunHyztAk-D;6B56>BR{O|zWi3`W^1w~M&??Klr zlW1>mEaShMi^~lg!APbBfrXPHQ^Lc4Ivwcbex;M!;M#&7Sn3jvy_e2_eZL!hyC|N* zz2$6PJ`cC5-Gkao6X52(F(6hoLGxdL4$YP^hjG@Ja;^td48t(n_ZKMUxs!iIEUi$J zv!tCGc*?W^DtbqP^hF2PwAB=|c-9oVB8ug78(^PhKl~9Ch9Sv?U@_B+R@$wl(%2+6 zu-gL1uKpFQp_^N`q=T!QXofz8LYiw4#f;V&;D>R8kT4C9U8+{@L&|it zJA4h2QZzsuH5CFnXJfj*6Aj;xu@ZkXoTjP5B9;f^qo6|YKjcB?W@{*Ohnx)*i!gKe z4h)^01i$%Bh0e4Y_|Kp71QrSu@sXlXE*mMYc{nJ zE`68;ho)(O@YYN;;W<*>BPk1fXoD-0`{B;tfmlEBEC}zo)7ERzq|1{tzZ5k*`9~f2 zG^oNx|1NG{y*Ww)EGc$R6w_8W#_x@X!9*(veeUGLDCR{KO%(Aa&nL)vx@JmF!fR$TtD2; zrFDPeI`ssYG{v2+#YeGa;l|j$Wf)R7g`(807`9Y;Q~w+VwbduEdPhLvpEV%J>*4ek zj)arP1;`KapigNl*|H->xGZuQM0vre`mO*jHh7b3WGwlOOJpq<-SOL$Tkr(>IC0)s zSbTRD^5Z!=dvOJGE1iaQ)dP@mBN*qcJqDXg80{%tP3yQgmYVUZqPZO5;2d=jY|??H zn)Y~3&xNu)q|Cw42($H7Sg>OdcHPg1jLoj}r&b(U*2J-j{s|cIst&};j8I&h3X==@ z_}BMB3Og@hu2(FvKD-Y!W(J{1R07HKT_|%Pmd@UYW8WR^udvbo6(ar~56!I_a7ku` zfch-oF|FqrO2Y8zIOp(Kt)xYku9x7>n=J0l^Z=mR&P?A7+W zHuN|~%DNBeVr2FpcrFXVShYM@YUWN=_fkl2dkWK8{l4O+pbk`ae&k{rJ2|MWw=jJ}HNIVku);Xal(Tk2} zNLk!NU5r~j1h))Bv8?bo>`?QhNM+wTijtYwKo>i{y9v$%-P|z;b+CNzgj!dR|0x(ZZ_9^gogNg_C8t1#WcJ3x7;jvwhHhamR~s}DzWaG5dPLjPR+)r3 zm1v{$oqpKg7>cdaioy1vn9AE$QGutNWuD-nX>SE+r=F-@HDSO|j zhu*V>!E9>?)|@^G< zz^y14({^5jnL{3w)+M8M<$9-EG;nP|4J3Jtgpt#pbEi%CI6lCFcC3)GmP=Ex)_D-# z90|snWyP>W%ZGk_yoSb2PiFU{XJU+1Ef^_2a9_`R$+{QUK-1rs@(xB*WKRO~d*p(*=2SuLwr(!iO}br3wdseVgdIjWB=$OsQfzw z-`AgjmT+Ivez%ewOcR-(t2dhMx(P`edbr~oG~j}!Io_&tra!eM%%pHChWYk^@sm(o z@bnl!x|o6wt)_BW95Z=lhG`L3LE1LRbsm}oFI{ahOC~1uT~gK-sgKSxhu}Mt5S&q- z4-pDa%0mV1WAV%+nc-HKS}?c#3MLgyg4`cw;df7+$jesB^y+7!x6LQ;S{{NGMW^6} zwFgbEm6Kq5JUhJH9L>hpL&k+spmtFmR;}S(rg6%9&$gBscvW`r4*0^=Nh$exQU3g`HdmB~QmBIjAIwu#T*FCA!HJ0r4*Rsr} z|KXVhH=yM}AJ@D<9rP|aV#-4?DSnl*H~&t@5n3utF-bWG+%E=*^Q5+>STcQ?#1115 z{rV6x$9?3uNOefpb3&_I9u!+8WAROfSY@cfCVK|sl~KoG9(Ynh@M^mKDv9wMEil~Z z8g%n|xs@v>L2HB&o*#CgampMFbuz|niv~fwFbt=iD}Y8nF$r!a(&6uwbJc4b>={`B zCYw}Y->0##T{H{3KRMA}K@=O%HAS5z%KJ-K2o{zXfxU(&9rjHj*8w@J(DBAoWwp?) zqY0DM#zO8hV8v_}86K;!34Ey1oECgSSI0s`D%Kv_#px?9POe)dF4@;_Hp6LLm zS^I#SanB47Oq@-Q%c7XgkP*H;JqUNxL-0q;gO5+V$;xv*J$RbL{FHrKsk|3Rx4-5p zrat9zw9T<;qb>cU8^wO8fM zTtlNR+Prn6#9b>H&(Rq7st-fK^e~(!EQHc5AL?*hLznj@va~-U@UKlbfzLX)2RS35 zyI>}|EpVjAtrFJ8>7(?_0Gu=o#t5r?z_H3aW^&R>S1|qVaky*Z74UmC3LLJg!zR zGLEiCN}0E>C0dpCLv!)FDQd4+^Slz?DEhd**Y%w01wM*2 zt;sY-#w}0y6ZvhLCG}LHHD>J8jHMuYVh-GV-xYy7by<|xd(>_xpFVrh|~GIz@iFu%GAybiqM22MZY9;pe@t;Ulw z_R5&sCO!OP=`cL*3qgyyMNl}#i`JOPsrUONc3#;J?$I|uaIu}cul|~w5Glmzi@ZrA vBZ{rh(8bkZ!(cEz7|*`WhvcW8q<3{SHIyk`wgdjHS_PY}-f@d24RZen2?GM@ literal 5104 zcmb7{`CpCc8^=#k6d^5WG3{DT+D>&^p6k}4PDIXGT~etenfW$oNElI+$kM)ikNFQgzdg_E_5S5^-`D&3T(PRW|DFwIj%J=4%p{7h zlcZ*zW&tkFW_+`h*p;8&mI(Pyf@OR&$x6k_<%uydDN;p>@~b0XAP_Fw#5Z<vEyL zg&uC7Sm)1Wl^oy6MIiVDyL>ME3FJGudpLih1)ohTLB6xI!0{99`nmKcknij!{4X-M z&!&|i-^I!0ztG%2m;MCuU0j9#r6&ArS_$$6jxPVD=JC1oC-BovQKq)Z#63J7uXsy5GJxp@I&J{)!?{s0-R}@1}BCr@!DxulB-DB zxF`n{W%h#S{0J-@J_F4wW)phI>CTEo)>37QN5r-8V5S!AIjaStM>BA3D@R#}B&^8P z9@`UoL4x6^@k-p~o?pyz68XT9d|de7!JFb^|8ayyw=uRfn%E-7woAkOtmMSy7-d`ZW!L z!=#0HptAr}bLP@zT^UubRIrjM5g1oq4xzR} zuF%beqO7^pwk4ig(iM#Fu^87LssK^O2d)UUA%3eF+Gcu@`rT;e*lvP~kU==3DaP_I zb76m-9|=w;k++td83`@XTS4IVZVbHppbZvJ_~`nEkuONuRW&QTXvt$Y8p3gVd;$1Y z1dzq}R4O~cc>gaH`78iv2QmcaT~vq^eg zPC7|)<~^u^fi^YpL{k;cZEffF2istCo)g)BCuROiO>w!}2;6=WhQ1lOkku%n_>zq@ zusM|-ZZ*b9EIzuzsv7p1c@D=Obd6wa^5Y>Wsj>tr6JXnFsr;gDB&NHRP$I zV29raU_eV1H0|x?ROe}e^EN9qZFZw!el)Y5Ivu0Z`=LQijFwutu%y<9I^APvr%cYq zj(byT=yVb4#bbbHG#Lcr9I+|Ki*~J*FxPozc=7HK47Y~k&%fk@pxB>GQsOD_p7I=g zC!>o`1rTB0Eq*3eC&a(o5OM0*>iAsyDvqWt|Hx)a&|el zv-Ee5D!BJR zma5v6fv86nV^kjfO3RB$My z3@UT_xDSUua&wJ@_na)p=TBLp>uXQCQxU~1zA(lQYx>~Iui?m5<^iwUmyY|yQKM=iGb%DfpGgFn zd(}bkzbVl8hXWq8VI=jIFr)LPxEO{&YkU|k)-Hssf!X9|E~oVUiA)wRz=L(Q@N@G7 zxKgbLp>8v9&esC!d>PFe7dv7_aSzNBh2fvk#ZWu!M`vfqsJ14N)p>q_Uh+GTwr~Q> zPS$`u|5)SWPH#$E9?kNO&BP5IJ&>|76h|dqgge2rY1-yh^e#1#`RA))AJ#y|X%)zm zzTs#bA2re)X>_=RL5KmGWDWp-ZYbWKQ2+vG<@~uOQ{XRYOnh+~zI}KT2A_9vi_<%} zukY~j&nIrQW0{1tXBePDHUz7i=Ha}tIbgt_Lm8!U)VyOYYf`#=cVG?hFZFQKTZg!k zUGCU;JAig%Nm%3&L%fhT2n$z5pkMtd=u`T7+w*uTcTQ#BuLUe~y8*K1E-vjKRp|Z8 z9nDQdRChwkTDBVCiQ++cksgLm#U~*AUJwPHiKAb?RNj|#e-ufpV5!eXZu6z_aHGu> zLwr05XG$4=w*~TAhrs@4rFZ|$hKLa#DqX6eX}1$t(>7zY_^}iUs?;DbbPSm8wZrCc zpo@oMSpF9_*!_aXoF0Z_ta%>X)0j=FS!-#kwVXB8x}r_=P1xwI1=Z`+z(!z>+ACZs z*eiy02Dszm#e<-l5su$%J_}7UKdQN#Of6k+V8tG%16*k_3kl7LLk zNLlnWSG=YCKZq`eVOV$(EPCos@1G>n^Ib}=YYdhit%7rPL)@%A_qlazt+7ewME5Vn zu*)CJu~U=Bw)uqP=2OREnQb8b(72v1pGsvVcc!ED)f(`B)yWxo-sjGpvcbeYcRKic z44Zt!9KV~a!fsWCV8yAUV3!(1*Z*BZe;!O@l}BdcP;nKA_rBvYeA>9PRRRPZ5j}qt z!~S?^fg9KGSaawC^wli@FHaFw-&sSAjtZt2Np-o2x>}>4ftQ=H;3e8Z|?Lm!qqgeSj z(~u?i!oa%)xIZHw@-=;F{-SuAaW;XS`c@zP`>#RS5e@jJQ4f+f@p0pHAdjUI#u=L6 z_Td3AeAL`@cKMUsG?ttmCoswJQE0fL76cxnA#(dWZpUnE6h3gE{xk{eoN0(NNBSYw zU5raQ@*s0TAcY=Ep|W-b^UpTKr>Ad0$v`is`>~6AaLyXdF1XPmg@n~=8{+xuL3naK z42_4fA!Bn8Z3$mZNA9FD-7Z^v5LgR}mqT3T3LYHsa={Nr11P^i!u(Z?@Ib);e6?&5 zjvhG);xZ9A?~kWaixg&I%Hj5hRUk6$;o2sThL$!Ld@@c%_e!N~BN?FjkpW0QyZ{d! zDS+myB0Bvvj=W_lOuWk%mk(FNy3BrVS)vw%eYC_vPj7F$qO zwlCKYA8A*CHID~YDTXk@S)=yP-Zc9AD0Xwlbkx!ufr*I`SQ303_Pv}%yKc#-WxkwM z4v)rB76ifi7|t`g8kq ztXy3SXH;|{wL>5NxM+iCg54NNzxbXZS{B$Q2SHCQRRC8rdJ{L#12jpz(qQ|BE z0o7nsrUw7B>f(NlQ2H&?nI`)y=g89-_m_=8t!5amRvd@j`AW~br_n{1G-h#L4Ud%8 zz~EnfT$!?)TY5U+vLl^fxsP#()T5RM0mbKsm^An8=d=-@WxeTkllV_#LmVC4vx z@{Kn9)NO{7CwkL{vM5$xHy!VK4?>}`?{2u84O`lMD1KKW1wWFpmIf7Ed>CQU2~Egr z(+9qxJw}h2MYBW_W_`mPFI0@czYb!wzMc!MfBRFax`L8i5?E=wA?{pP3-{_LgV_mP zhzPL2^;g~KSVc6eb9cs`!hYcCiE;9lvoOEGm*#FuBHb9J2S<&;XFuNtg~?>lzdaV_ zY_P(gUjeQ5ie_fn4tVBXKQu3kzU3 zfPE0XA}Q+)H^Slz!=RC(?5&S-A#r()FJ7?3*j@6`*kia@d+2b`$gm0By=~S#f#C6zVRlN_% zGNp`H`~{{i?S~&`EX3`%3n3y#L{FTPNN19qDW=ax)iXE2`IrW5I;jhSi{@zG>`o@i zy>aSGd)zzeBRD23K<|AeQ1D{_<@w5JaE@}{47bJhtG8hK6%B|IjDbd5D;y_eq}-pR_6sSyUCLYp zQ*gI>KU669#v)Z^|J0jLSNF!!gh{#=lj<<)iIfRfQ(}A0S}~5&!@I diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors index 5188d8f428cdddf509da7f91f29950d9c80ad309..7dfbc3b35cc394e87cd77018d4f427928b9aa637 100644 GIT binary patch delta 110 zcmV-!0FnR5_W{WF0k9BvQS|F*yWiJSJZkQmIefA!Jk{F(xgs{;JMQ7>IWSb)JK|an zIg=~7I~5V|xpws>JPm4_xc^&QI}Rt-xeTT(Je8;;IfRu%JHo>fxUd`sI~9{*cNYn| QUje!FRLwdIWJV(JKkCj zIg=~7I}{P{xpws>JPvA`xc^&QI}az;xeTT(Je8;;IfRu%JHo>fxUd`sI}(#&cNYn` QUje!FRLwd=vw?R<1in}@TmS$7