forked from tangger/lerobot
Merge (No verify)
This commit is contained in:
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# keys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import packaging.version
|
||||
|
||||
V2_MESSAGE = """
|
||||
|
||||
@@ -72,8 +72,9 @@ from lerobot.common.datasets.utils import (
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
decode_video_frames_torchvision,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
get_safe_default_codec,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
@@ -532,8 +533,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -543,7 +544,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else "pyav"
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.delta_indices = None
|
||||
|
||||
# Unused attributes
|
||||
@@ -762,9 +763,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames_torchvision(
|
||||
video_path, shifted_query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
return item
|
||||
@@ -1180,7 +1179,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
return obj
|
||||
|
||||
|
||||
@@ -1205,7 +1204,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
https://drive.google.com/file/d/1_SOJkgfP5yZyVjMhTt3nwhvyUjcnlI51/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1rmgN8UUzph1qwJnzG1d-uOafodn-gLvb/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1NYQ-XxsBVinB6dUoZmVWweT83367P3i2/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1oAv_j74zxxCJieMG7r5Vl2BeHK1__3s3/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1wFUJQROsrTJt64YRuIeExhFjr2wnK5uu/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1KzL3Tt0Le7jVl58XVRUcmigmXjyiuhbK/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1qy_YBladeHtianSSGtgAPSHtMin7msvf/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1rA_F0V_qL_nyuC_0aBKCisF4-0TIkF2Y/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1hw-8qMpz9VgSt62XoASqNRuPECpCwJQP/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1BpHOl9rKMzdvNGka6js7C0s40hH6vnDA/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1PazhkhiDnJ-OUMyDVDFxEZNKQQqHiNWS/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1lZ665R6ATl57dypxH4dGJ2NSt6XYnbuz/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1V9HzLaf-tlG15wUzT7KrTDCS_z1vi5NV/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1aKauWiXoKqbNwn_2xs4MrmLlaNYlVNmO/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1WVD5DFhriO1YmmOgiVHhacR6HWoTPxav/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1_X43WgeBAsfkhH9EmpyPki8U9joMeAGC/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1t8x0GqWoNKWtnBsB7_D40Z34nL9ak4kf/view?usp=drive_link
|
||||
https://drive.google.com/file/d/15V_f26WaKOXjKnq2T3HRWAmtQUi4lbu2/view?usp=drive_link
|
||||
https://drive.google.com/file/d/11VFIAsiSDsMOBANgrOcZBpKB9AFWnLy7/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1M0NS7vVaxJv3FHnuRYtdwTFYF7We4LxP/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1mR0OItTNqFnVLoczcyKYlm6drAy778lO/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1NbVFWDQAh-z4JJ4D-Zw6Lps9kdvpqh2j/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1JQoZGBzl4W3QG26-n39tefcGN0fDRMbB/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1VBjHl-TvZpncopvasIP5G9gecbB2a5f6/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1VzSf6zaB21nahm7MsPwroXbJ84NIwq0b/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1OtNnfMEydNtZOcivs4k6E_uJSpf8PkGy/view?usp=drive_link
|
||||
https://drive.google.com/file/d/14nVvpvsrFr_03Pa_N7MKzwnRwibOUYM6/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1M8li6duiO2r3lv_9HhF_XJn0oZUIEK5F/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Cpzea6fO14lxAaNfSBifqoa4ekhCiLD1/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1mbxRTm5vlbsY9UJ0jfjM6j9D7kPJjBpG/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1RXD1i6IfWsHRlCxVmG04h2h5Ycm_WwZN/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1QFqFSwDGOk1BkgGmqgCcc2BRWnJ6R3MA/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1bFqWR8DQM0ZUxxtS2bl-RANQvukeFLzp/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1pR-rH3yNGoyPdD4hJ6-3lXQ-PstBx9du/view?usp=drive_link
|
||||
https://drive.google.com/file/d/107OAwLY-hva9HeQLIK7VCh-ytdDabVjr/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Tpl08QOaSZ37GTO4awFWSdD8wBR9xdlT/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1MR164AOM-0S1T6RX8xKTV2IHyaCvpqAW/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1_wknJfVnStIhJ82lU_QtcrwahsqYIsr8/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ZuEktWrbYkTx0l5pj3WiZ2CJrfbDOHNo/view?usp=drive_link
|
||||
https://drive.google.com/file/d/15G_10hkkkq6yxvyI5NGZirlF-RzduR2F/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1DBKxg3ONqh7dhLuX6oh1Yyo2x383V1Hp/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1B5iDBkTUr5vopDddV_fHud18SqAHhauS/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1acwFV0eenRkki1QcjSKH5xqOtys-P3Pr/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1S47BI83xyrh-FKXsvAQqer98Biu_p8XK/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1JL6DmBZl3uyq9dyLfgSqtGF06e7E9JwM/view?usp=drive_link
|
||||
https://drive.google.com/file/d/16WvRS4Kjog8Pxgr0E3sGGnI01YwL9Uql/view?usp=drive_link
|
||||
https://drive.google.com/file/d/12ttGqL33IPWg0-s1SD44rr22M6LiSQBr/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1OyZqqnldTU_DliRbr6x0C4a_iWPwIN7j/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1oYk00IpLnR9fesLfD15Ebe7nVBffEbcS/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1eyE2-MQduCEqCd-5_kl5zsoOEERAzpZD/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ir1Ya-vO0d97pfvbePlUeuKTTRc0qIMU/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1hOi-JnqlMt47gVnLZHMTqeojyYVErohl/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1NFFw5_PqigQ7xGqsL-MNq2B1r5yAscCf/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1uftq1-Zlh8d2sNLWrlVcKYQUwZTD7o24/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1-ax19dSLPacVgk000T-m3l4flPcg07pM/view?usp=drive_link
|
||||
https://drive.google.com/file/d/126y-lgn86-ZmCz8hooF1THKJGGObw3OB/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1JiDniK0VmDIkk92AbBILb8J2Ba59PWML/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1kr8nPIRljiU0R4J9SMgj80o1FPQxzu9z/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1bbThWRij1pKBh_kFgV8FwK0sXtTHBoLX/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1WenzDW6lxk1xkOFm-OiGFfc0ROskAuKU/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1MiKRzuzUn1yN-k_6kPJJzIGy7dT-nnsD/view?usp=drive_link
|
||||
https://drive.google.com/file/d/17rRg2tcmB-gNhQ0KoZJQmNfyFeoij1jH/view?usp=drive_link
|
||||
https://drive.google.com/file/d/11mokBpvrY3ld6sY5WztREtJ1jgqfQV70/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Il_6IOx9NDp1bX_KHizJfBwzTufTmn86/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1KswtJGsxJ7eeBDAmNA_aeLjOxcH6MIxa/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1gzMhi5uWu4C3Y6WbQ3L-08V96GxTZrRR/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1nRQFtaBxfUCYc2W90Qibh0kHCt6YQCfc/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1vs-gyW-KheqHbUATwAhA2mmR9GOGw7f_/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1MuxzGOA2fgLaHryq82KkQumtuRJGcUOC/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1IIwxZnGlqrXLUXqG6yMO0r7uhCvhpk9e/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1vE7XPyaFcXP4DtTY5Y9WKIt7zWgmX-Cr/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1j-bIV09gr21RC3-x1N_pK4RPLV3fmWKz/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1t3nW1rD3S-EL0Oymb5U7ZAj5UMkydkln/view?usp=drive_link
|
||||
https://drive.google.com/file/d/14hbfHCdMKtJZ41F9CQReMec2jeRFTOqR/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1x-hUyOSne5BW0AzQ3W6_Pf4g5yXQWi9M/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1sw9JqRg6E-3P84I3ZhzTrJMu0vuiaMmP/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1LuqhQlL4MGZhB_6THmkovRxrlP26BbdC/view?usp=drive_link
|
||||
https://drive.google.com/file/d/15C5K6v_lkjnMSmUvVyqHQKwh2N166e7K/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ns_9eSsQeeoZ10nlbkLy8tu0GmJFSnkt/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1NpzWJeK6CqjxzjIMYe6aYdX8xGsQwD4o/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1NMLezwufKJ9_8xTc9KQThSzVVD71B9Ui/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1aa71DCUqs6oXlIxX35jgsmsgm-NlDxPV/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1UJzkIZzAL0j-D5YQBnoq7mHvttASy12O/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1nPgx36HIJFb7oI94VbRzWjpPP2GANxzG/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1NovAP-KVJjqcuvWy3d6G4ptGGAIDqcCx/view?usp=drive_link
|
||||
@@ -1,55 +0,0 @@
|
||||
https://drive.google.com/file/d/11M3Ye0r5agMaaicPbVGD0q2Hb3rGklbb/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1-tx7SvYYgSvXCvnf_EI2OVdwK-CkFY6S/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1EWJunmOpMHaU1hE106wwpbkGYcjQXYAF/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1IDn95Z7FSiCckrSENtGV4u3RyFHNQSDY/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1CwzvWj1i7QOtqrZvsCZ6BdZaKNDfpN32/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1HvAvlhm77nAD3Td24QPSeq8lw-Rl_aOh/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1t-suKYOPhXH666RpAYNRp2QU_DOy3AeM/view?usp=drive_link
|
||||
https://drive.google.com/file/d/18xpKgWh7RWyjMN5PkLTOo-AxsAadAuRw/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1oci5Eto-ztv-AQNz8EnwZveBIhxvk-xJ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Y-t_4vxdE6NpHO0DLJR8f3mD0Q-Wj5-c/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1lylRqbbbB8bgtpsBWMPACmHJreuKmllv/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1yliSyMig_NXShWfQx6qyW7Ijf2Y5lFK6/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1XXhwJsJbeb7KXAooGvJapnm9bjnGUmxS/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1_xs1f3hW2JArKyvfF7UWubWjyROGTLs6/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1WVEHpr6EqKCZbkHapQSTXJq4xE4SWFT-/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1RqOHv9pEQGvW8NUA7ynffFmG999TL_Az/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1cu5AgD2gh-uA3PFJmzxxzNaF3qOSlYY1/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1SsrXqiPclNrnYToPZ9Uq-k3y0C4qdHT1/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1-J7EXf0vjkLIfSqT8ICEsP6CTjzSLBop/view?usp=drive_link
|
||||
https://drive.google.com/file/d/11O7ewUmoZXfyyKjy_6B5RW4DpjICxqBT/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1iic44kZoCsjNsfAz2cMstZ9-WQvAhblF/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1yLV1lVX-2WnWQldGlnQZ0x7QBuDiVkL3/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Tybp9ru98TTbGn4eyROpUQwDFuALWXmk/view?usp=drive_link
|
||||
https://drive.google.com/file/d/13E9OTMiipVJByDs5-J19oWwAz7l94LTN/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1EeTpJQdMSliw4JzSMtJ6CyTvVdexjM4M/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1NHyNwoFqzeAu-1_PSpq5JfxaiD_xbpn9/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1fJcS0phDp4xm_FyGaJ5wr9Pe4KqtHaxD/view?usp=drive_link
|
||||
https://drive.google.com/file/d/12AqrLUaewDPEcFRqPZeZFb_TQ0Lfi3At/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1x_hd4Qsq1oJS-aj2t3qM7WbbV7KZj05b/view?usp=drive_link
|
||||
https://drive.google.com/file/d/14OUSUArmsB068hs6BuEIXQhI1Cyz8Sf0/view?usp=drive_link
|
||||
https://drive.google.com/file/d/16zlzh1T5zeUJQnFf382NXkFEKEnDub4O/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1IbDltmN-NEFCNtr1TO4ILxEgQ94rtjWv/view?usp=drive_link
|
||||
https://drive.google.com/file/d/15gmlf8Gx9455pZ1AlqcCSwh3nDPxMzSr/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1qHpRL1oZfIMo_vxnm8qfwQ-7l0BZIVva/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1H1xskIgiFZivkYn23rMzH3xePGOh3VTC/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1avls6Pv0kYiCMNVknbc1zQsgy64MUDMM/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1MmWVgCj5khc8KMIifmt3EzF1o-CtPyyn/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1U0kCc_xqW0WNppf4sbnK14euWKdPZtzB/view?usp=drive_link
|
||||
https://drive.google.com/file/d/16CaEyQscOuhLj23PEGDTL9DeyNkohkMn/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Iu8uM6UUJ0zW8tvN-9UiOe_4oSNzEutg/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1UImqiBaIxCR-1DNJaZhHqeHhaySOtVIr/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1VpU2V_leIoRIyv_lAvE7eLHBG8DxCTnp/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1_Q8J27OT3Xby7QY6yHvIJauFRWEMxkRm/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1bantmVo1L9Xz4tbiNw_a1UC2Z_HPO1wT/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1IRIXMJMCBDkBjbaHvAlEiBogSvZ1jK_3/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1mAHXKjiFbjwydypW2t5Lv8_H5x6nHegl/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1SfyY796fLrBCMY39OcyuxZafqSCRZPZk/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1X-44sZ8CcfzIskc0dvSx882o1yFhHaZB/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1BOIWCCCk6DLD4Bmvc75ZbbLi9AQm-1ao/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1RuyDtRE1kk76sw-wP8vx5SgLoPF3PA_H/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1c4eoQiBbGuy3CTAQDUSkd84Ponh1roAQ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/19PXB9z4Ljq6dsbf9TqcOrrP5SRbw2Tc_/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1nn1VVZVoIXWdYDozR7XHXE4mPLQG80PQ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1MBdFGOKPV8GUhwoSsJ_Ky3qAMLM2Bv3K/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1of3k_M-7Nh3I1TndcWedxK4ca9dn8Sc5/view?usp=drive_link
|
||||
@@ -1,20 +0,0 @@
|
||||
https://drive.google.com/file/d/12ctkOAdkCNGN1JLbZb5ww3XTBn2LFpGI/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1G_Vd46_4fq6O64gHHjUbJX5Ld44ZZx0y/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1uKgUy73B3xBogQAOUhfZjO0X5qZGsi2c/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1fu9cIrfI-fE2LhdGUxbx7-8Ci_PF8Ypm/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Ygk9ZPJzx8xw2A9JF3NHbJ44TqnvSTQR/view?usp=drive_link
|
||||
https://drive.google.com/file/d/18m5xPuccNsEB20WPshm3zhxmXc6k63ED/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1DiqqxC44rriviRQpqogcv0-EB-Y6nr9g/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1qPdaoTVDizJXkfXLioWU7iJ8hqCXSyOQ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Fj9kIA_mG7f67WFfACJEaZ7izcHG7vUm/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1WpYehZnI2P7dUdJPfkE-ij1rqCnjZEbB/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1_zwWkT4jPyzB38STWb6whlzsPzXmfA9r/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1U6-J4I_fPlSFFGfhZPxS5_YzKXwXIZYp/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1pRhxxcTfZp5tQo_EScvJUwfc3amiS6Vk/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1lWLntqra83RlYU_gN7Vostnfydf6gutd/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1vIBKo0x-NYEHV1FvRpco1lQMpRdAWAIL/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1pdrLV3JTQou_XH0Aap61Ssf60iVKm1jJ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1QTsLoQ7SwmKdQHjBGVDaR2uTwfFwtrOf/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Gytai8M_12J36GY6L_TulEcOC-035jwS/view?usp=drive_link
|
||||
https://drive.google.com/file/d/14LJudNc629NT-i8xreXtzl27ce_DxOFJ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1sBvPCODbzxGAI0S3lgN5cSG9Go3lRi00/view?usp=drive_link
|
||||
@@ -1,18 +0,0 @@
|
||||
https://drive.google.com/file/d/1MJn9GbC8p9lN4gC9KDMLEkTkP_gGpXj0/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1-4LXgjl7ZCOgp-8GCJmFRD8OeqN5Jf7-/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Ho06Ce0SPbqU3juaMxNUwAt3zCRLGC8W/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ivHoj7_7olBSxH-Y8kqXEW7ttITK-45j/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1qjY4hM_IvZ8cq2II_n9MeJbvyeuN4oBP/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1rKVhO_f92-7sw13T8hTVrza3B9oAVgoy/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1pcLPHO8fBkc1-CRa88tyQtEueE4xiXNi/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Vev_chCsIeEdvQ8poEYNsOJFGy_QU8kZ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1l5G4zpRkxSLCQjvGPYSN4zfCvVRQuzMz/view?usp=drive_link
|
||||
https://drive.google.com/file/d/14vgthE1eoakXkr2-DRw50E6lAqYOiUuE/view?usp=drive_link
|
||||
https://drive.google.com/file/d/17nPSmKKmgQ2B7zkzWrZYiLM3RBuFod82/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1QcDsxplVvb_ID9BVrihl5FvlC-j7waXi/view?usp=drive_link
|
||||
https://drive.google.com/file/d/18pEejBpI-eEVaWAAjBCyC0vgbX3T1Esj/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1H8eH6_IRODtEFT6WoM77ltR5OoOrqXmI/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1IWlpFRZhoxyG4nS13CWK4leZVk5wbNx4/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1PbZA8_OCGmMLxNP9xbkLRSChniL4uGxl/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1p9XAdmG2f_WeflNO4DIJ_tr1rK6M9B4B/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1nS59Et1cNAvKo3Y4SeSGRuZD5TvBbCF3/view?usp=drive_link
|
||||
@@ -1 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1S8eFg98IaGAIKVZ8QFWG1bx4mHa-O204
|
||||
@@ -1,4 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1tC_g1AJ8lglBLY-fjsQrG6DMBa3Ucp-0
|
||||
https://drive.google.com/file/d/1fG_Yi2MJrFjiUVN3XoiWXLtTxHlwwaDv/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1WX32VWfzzX3Blmd06DRxLwFbMJfVe7P4/view?usp=drive_link
|
||||
https://drive.google.com/file/d/18onsX3vXg3xkFwP5bVUCjdV4n9TRn0C9/view?usp=drive_link
|
||||
@@ -1,3 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF
|
||||
https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link
|
||||
@@ -1,3 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N
|
||||
https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link
|
||||
https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link
|
||||
@@ -1,3 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo
|
||||
https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link
|
||||
@@ -1,3 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj
|
||||
https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link
|
||||
@@ -1,2 +0,0 @@
|
||||
https://drive.google.com/drive/folders/19qS_n7vKgDcPeTMnvDHQ5-n73xEbJz5D
|
||||
https://drive.google.com/file/d/1oC31By0A2bsBeHyUwBdQw1z4ng6yi9Za/view?usp=drive_link
|
||||
@@ -1,2 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1m5rQ6UVH8Q9RQp_6c0CxkQ88-L-ScO7q
|
||||
https://drive.google.com/file/d/1wHz2qcmwcVG0C0CZ9MjQDQcmj4OY9_a3/view?usp=drive_link
|
||||
@@ -1,2 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1seQGay470nGQ-knBI5TjsTr8iL9Qws5q
|
||||
https://drive.google.com/file/d/1T89hSX5U99wLGvGTE7yUBaQPOpyj6Sai/view?usp=drive_link
|
||||
@@ -1,2 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1t3eDc5Rg0DveyRe8oTm6Dia_FYU5mXyf
|
||||
https://drive.google.com/file/d/1TXFaduTakvS0ZWJqKCX-HIvYglum_5CY/view?usp=drive_link
|
||||
@@ -1,2 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1Z9X3DNzd6LS0FFjQemNUMoMA5yk5VQOh
|
||||
https://drive.google.com/file/d/1Wlyc0vTkjXuWB6zbaVOWhEfD7BmPgUV_/view?usp=drive_link
|
||||
@@ -1,53 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1DYgB4ifX4uIid9m9jnC0Zdz8Nf7ZC0fc
|
||||
https://drive.google.com/file/d/1Eb-NRNk_FmVleCbU_Ng5Y4dfcjTKN7Rv/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1dkhjEADakT-44l9jf-nK4x89kr4yG_qb/view?usp=drive_link
|
||||
https://drive.google.com/file/d/14hDhgcZkVqNExGb4tIXpSjMshhqZETch/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1zVMEHpHbuNyP5A_lYU7RPSLB-4V0yfZw/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1JtgDjBvy7FnRpFzrx_foC3quorYQFAR-/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1EHdneB6F-PP0dQlX8qPaXbxmKoBy_YwO/view?usp=drive_link
|
||||
https://drive.google.com/file/d/17Z0jjVBy1OPKREPu77_n_rQzorDiapji/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1F4i23qPJ_qTf5jWjfLo4ARGJChznYWt3/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1kZtXWM3uS0-rLblydBfJ0mMcVnMMXw9w/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1mNODox87xFfY5Z_o5mcLsr8SHb39jDik/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Ob44VdmEUA93FKDECiRb5Ogz2xQg5IWp/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1fdQLdjj3Cwv33R1wZhfrLz9Del8mqgHb/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Yu3L3ft21zP__XL8pCfhb788ZleuW1n5/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ozBBWXVZ9hXDh9ooHUNroHdYm8UDqnhJ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1o0TGqvfWw_Lunxb5ubKDS21Lr_WC0h75/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1jZnd5eP5L6BH5l98BPN6OnoQx3fu8e9n/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1S5sYbz8wcLYp0V67v13i4PRcBxodn4Hg/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1rFeg_x6ftJYwPtBv34D3h2L2cpDLeR4G/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1GvS3lcm4o6nm_scUk0XxKeVFNmzjucDZ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1-9i0riphC7NhhDahcQfD1QoBXP5gF90A/view?usp=drive_link
|
||||
https://drive.google.com/file/d/15p_IqGsMbKuvzMS872THAZr-3SBtb1Fr/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ToyYcBfJL8gbQn0q_59zPLsFmm7dmMJo/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1e_7PNH7CYafE4pAebP7ZdI7XFbmEcy_i/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1JoabvGVsIQdug2xOhUIhetEIyDM91y_Y/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1kOMw1y0lmnVaCjwZICfzCsx6e0Z8MNGR/view?usp=drive_link
|
||||
https://drive.google.com/file/d/16it_wd1JOevUQTK2_CvF_pBACTgpIPgM/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1IRcCj9HnJSfbyMgr5XEERGlEnWeZQwOc/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Z2dIJfq_S3liGmPN9Rphvkmucnmw7tlb/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1J3NoAjzndGx9yNyaBOJHdNny1epzUoBt/view?usp=drive_link
|
||||
https://drive.google.com/file/d/18nOvxV1k8FSmBrhT4TPo2sKKSZXougyx/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1CT8FxclafFMjSd7gCWVw3VSeryeiF04i/view?usp=drive_link
|
||||
https://drive.google.com/file/d/16M9KVqQMFfSsXfypK0bocFft8Nz3j2Rt/view?usp=drive_link
|
||||
https://drive.google.com/file/d/18QPVkw6bj6HW8LTPrQLWrrUX4R6RcF42/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1hQTVtA5hBTE_StXpJafTZJ3tgt2VQQ_t/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Dn-d5g69H6EgAWgsFdrcbJKtz7ySsCQ8/view?usp=drive_link
|
||||
https://drive.google.com/file/d/13hMr16483P7ALYv73yMRUN37fJdVQM62/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1848yN3XMN5zJMEgApt6KzrWgfRPfimtv/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1oAD9kSnS0fTgj-CjD4u9VdZ5X67IOIMa/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ilzIWLCCG5b_KgF5s0wdN2I5-lFNpwC1/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1rjsT2YBjnidxod1s9s-myAYz8boHr-WB/view?usp=drive_link
|
||||
https://drive.google.com/file/d/18Gg48HTub15bd8qzbhiCUufbVy0fbN5G/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1WsSnQSqmMTVSRwrhT1Y-v782My2zcjLm/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ea9ZCvoyc-xqiFXgeDcA_mOWsw7VUuoi/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1wv1v3-XhPgbNzp62BXbJTDzMPu2tlDUc/view?usp=drive_link
|
||||
https://drive.google.com/file/d/18-ikzt8LoZ83Gi3goKCELs4U4z8hrRoF/view?usp=drive_link
|
||||
https://drive.google.com/file/d/16Bjhp7JNCXkGuLvyNcZowAx3W-Y-15DV/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Gc-KRI-xwcp1fMR55ugbrLg_5y3SPde-/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1oP72Q386Z4Sy5MMm-t5yNogIe5Van_9k/view?usp=drive_link
|
||||
https://drive.google.com/file/d/112T90eDUDVH-SyOV7UnZl5bscAH2hcfq/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1y-uKOesRRhjgDtFbG_j65f4SGg0v8XDg/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1LOP05OagoI3km-ZKQBrS204A85UVk7Ok/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1QkHQKgasVzWsmdPvkXgGhWyQ84d93_Az/view?usp=drive_link
|
||||
@@ -1 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1Ut2cv6o6Pkfgg46DgwVUM7Z5PkNG8eJ-
|
||||
@@ -1 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1FqxPV0PgvgIu8XFjtvZSPSExuNcxVVAY
|
||||
@@ -1,2 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1SKtG0ct9q0nVdYssJNMWSOjikcXliT58
|
||||
https://drive.google.com/file/d/1nchD21O30B3i3LDoqramo1zgW5YvpJIN/view?usp=drive_link
|
||||
@@ -1,2 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1_4DHf2cma0xsChLQFghwigX6Ukti5-zQ
|
||||
https://drive.google.com/file/d/1_8vS4hDNDgUQY-SmekrNaa7dF67QJYU-/view?usp=drive_link
|
||||
@@ -1,2 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1_4DHf2cma0xsChLQFghwigX6Ukti5-zQ
|
||||
https://drive.google.com/file/d/1_8vS4hDNDgUQY-SmekrNaa7dF67QJYU-/view?usp=drive_link
|
||||
@@ -1,2 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1fAD7vkyTGTFB_nGXIKofCU1U05oE3MFv
|
||||
https://drive.google.com/file/d/1XzyQ2B6LLvcurIonOpEu4nij2qwNWshH/view?usp=drive_link
|
||||
@@ -1,53 +0,0 @@
|
||||
https://drive.google.com/drive/folders/13EQsVsnxT86K20QAoyE_YpsFbQ7fZQdu
|
||||
https://drive.google.com/file/d/1-W_JHghZG65FNTVhw1SXhtQrazdLL3Ue/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1VwRJgdWUo-2nQaNM7Bs77-fsm8iwUxEo/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1wFzGRo5iYA13WLi6IV1ry64RyahQBFio/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1IKtQzQ-n-UTv64hYpReu2R4cqUvmNQqD/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1GicVci9OiuuZZH79i5Mg7AtWod94MzwT/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1JVnIoR7EIQp70T4eAf9RX65JcTrzsjQc/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1W2xr4h23ucjPrc-mBEeqnACsfaImpc0p/view?usp=drive_link
|
||||
https://drive.google.com/file/d/10xj_0V7A07o3uCa7v5omUrTC0YlPW8H3/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1FOc3EMaCy8Mb0_a7PuXLAwKwvxkbKmwU/view?usp=drive_link
|
||||
https://drive.google.com/file/d/143PgDXBcf2GQ0Q07ZPMVMfBgZDd5sLJG/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1pE5Tyj0LlGbGWvUzuhixp86Ibu55Ez3I/view?usp=drive_link
|
||||
https://drive.google.com/file/d/141668b1VzX80ncrVJPzhkoAeIFB4MEK9/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1bw12lo37p1ZvRvErHsll7cEYi2OxscvZ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1zfnMFvbgBjl6SzYhksbaOzfbwLrCN6tb/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1-GIszA6mUJMaNB-tdh9r9skc77SWA0VX/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1fTB0zWFYU6zh4IIUFT2zX_OkwYqmElwY/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1gPIPNKGmrO9c7gKF7SP0SuUYbIBBq8z1/view?usp=drive_link
|
||||
https://drive.google.com/file/d/12JeJ-dQd5lYyn6PlDOGdE-ChVeiZ-Uv0/view?usp=drive_link
|
||||
https://drive.google.com/file/d/100_20cgCqerU6qoh3TfTbwLy9mlDAFEG/view?usp=drive_link
|
||||
https://drive.google.com/file/d/111oAGJ76ku_pYgbBoIdZAC1_XEQcPI__/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1UhC8L-354ZQ2gblPFGI35EMsVwfpuKa0/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1sIXQSgUR_xdrNtGrL6QGBnkLMKErsIp1/view?usp=drive_link
|
||||
https://drive.google.com/file/d/16Ax77bDSIXnsn4GFL8XYKKT1P6bPpfMd/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1pgRVYwwVIsWq_qsWqZpe1UBzZfF5Fa9D/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1jtimaZkWsY1P5gC2bbS64H_WCUU7HXN2/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1N6Bh02P-RiTEgtx1YH1Db_X3TGpP-X_r/view?usp=drive_link
|
||||
https://drive.google.com/file/d/14Fy8EwJ8d9Vh97Yt1VOvUChSCrfIjBij/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1IRuv42dvIMPuKhcMZmuXaBjJ-lPFOmQd/view?usp=drive_link
|
||||
https://drive.google.com/file/d/16XWzNY2D8ucVVn5geBgsVdhm3ppO4que/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1xsVOoQgthK_L_SDrmq_JvQgUpAvPEAY8/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1bZbw66DyEMvnJnzkdUUNbKjvNKg8KFYM/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1CyTVkdrNGGpouCXr4CfhKbMzE6Ah3oo3/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1hDRyeM-XEDpHXpptbT8LvNnlQUR3PWOh/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1XhHWxbra8Iy5irQZ83IvxwaJqHq9x4s1/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1haZcn6aM1o4JlmP9tJj3x2enrxiPaDSD/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ypDyuUTbljaBZ34f-t7lj3O_0bRmyX2n/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ILEEZo_tA9_ChIAprr2mPaNVKZi5vXsO/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1U7nVYFaGE8vVTfLCW33D74xOjDcqfgyJ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1rZ93_rmCov5SMDxPkfM3qthcRELZrQX6/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1mYO1b_csddtyE3qT6cwLiw-m2w2_1Lxh/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1xz7Q5x2jikY8wJQjMRQpRws6AnfWlHm5/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1OO8GaO-0FrSZRd1kxMYwBmubyiLOWnbl/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1EXn4NVDmf-4_HCy34mYwT-vwK2CFI9ev/view?usp=drive_link
|
||||
https://drive.google.com/file/d/10hH70XhXRL9C5SnAG4toHtfHqfJUJo4H/view?usp=drive_link
|
||||
https://drive.google.com/file/d/18tiBcxea0guUai4lwsXQvt0q2LZ8ZnnJ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Q8R8qv37vk5PQ5kQ2ibx6BFLOySD0VpX/view?usp=drive_link
|
||||
https://drive.google.com/file/d/17aNriHzjhdibCyuUjQoMFZqjybJZtggG/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1LVjEYHSdeKm6CotU1QguIeNEPaIaFl_1/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ufAhE_EkgJ85slg2EW8aW_grOzE_Lmxd/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1wtzLtXrkw9eXRGESTPIOlpl1tInu-b2m/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Mk5qvVtD_QHwGOUApRq76TUw2T5THu6f/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1y1WQ3hboWVJ68KEYQQ3OhreGuaUpSgwc/view?usp=drive_link
|
||||
@@ -1,52 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1dxWh6YFZUDt6qXIoxgD9bla3CiFjZ11C
|
||||
https://drive.google.com/file/d/1hNBJN00SCAlOl0ZEgm7RRGbAGDjyBs0p/view?usp=drive_link
|
||||
https://drive.google.com/file/d/17He0CVwXGeoMmXg4SHKo-osNn7YPKVL7/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1laNKUVID1x2CV6a2O2WQjwFewKu4lidL/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1pNf36xbZJGRArYLmNAvRj5y6CoqdC6kB/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1_4E1-y3JXk5I0ebycLYM70YDPK9g52gZ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1PHfzhGPdbolKyOpS3FnR2w7Q8zUlJXSk/view?usp=drive_link
|
||||
https://drive.google.com/file/d/17ls2PPN-Pi3tEuK059cwV2_iDT8aGhOO/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1LWsg6PmCT00Kv_N_slrmcwKmQPGoBT3k/view?usp=drive_link
|
||||
https://drive.google.com/file/d/12LckrchoHTUVH7rxi8J7zD9dA19GXvoW/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1VqrJKjAIkj5gtFXL69grdSeu9CyaqnSw/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1g5rQYDBZvW-kUtYPeyF3qmd53v6k7kXu/view?usp=drive_link
|
||||
https://drive.google.com/file/d/10kUgaSJ0TS7teaG83G3Rf_DG4XGrBt6A/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1je9XmneZQZvTma5adMJICUPDovW3ppei/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1v28r6bedwZGbUPVVTVImXhK-42XdtGfj/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1-TEEx9sGVvzMMaNXYfQMtY2JJ6cvl0dT/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1YdBKdJFP9rJWBUX7qrOYL_gfUA8o6J9M/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1X9vffwQHNUSKLXr2RlYNtbWDIFCIDfdF/view?usp=drive_link
|
||||
https://drive.google.com/file/d/11hqesqa5kvEe5FABUnZRcvmOhR373cYM/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1ltTTECjEcbQPgS3UPRgMzaE2x9n6H7dC/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Zxqfa29JdwT-bfMpivi6IG2vz34d21dD/view?usp=drive_link
|
||||
https://drive.google.com/file/d/11LQlVxS5hz494dYUJ_PNRPx2NHIJbQns/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1i1JhNtnZpO_E8rAv8gxBP3ZTZRvcvsZi/view?usp=drive_link
|
||||
https://drive.google.com/file/d/11jOXAr2EULUO4Qkm748634lg4UUFho5U/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1rj67wur8DdB_Pipwx24bY43xu4X1eQ5e/view?usp=drive_link
|
||||
https://drive.google.com/file/d/15ZTm6lO6f_JQy_4SNfrOu3iPYn1Ro8mh/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1q4gBtqWPJtCwXEvknGgN0WHGp7Vfn1b9/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1t17keyre47AYqm8GgXiQ7EcvcUkeSiDQ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1OYUPGxtZgOF86Ng_BEOTXm_XOYpuQPsO/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1cBjbGHi3dwWHtx6r9EQJi0JT_CE3LuHt/view?usp=drive_link
|
||||
https://drive.google.com/file/d/14qaMyF0mcbCB-fCYKNyo5_2NahSC6D5u/view?usp=drive_link
|
||||
https://drive.google.com/file/d/12FgX86eA7Y5co9ULBVK80XMsiKQSs-Ri/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1yvoHWidf-jdBVw6qCCXOFfkVwKj_2hPk/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1a2SugsSDlC8UtUrFzp-_KAwyZckQOvdQ/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1l8pILBFSAosypWJMza2K09Vm7rug9axm/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1hfPQ8dBCk97PnOhq6_MIISm3IEzcOxJG/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1PPAUwlJCFKpms8cqF_k1v2_fCgDBOc3S/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1lVKQZeqFfK3amEmLuFhYLUFQ2eyE8rOW/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1K9iPMLfDowcIFoyzpvgn88dQ6x6kVwNG/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1PNvMqG9tL7QxeLaYBGHiWYR6SYb5iIct/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1xkRtzbvIkUsylx9hrFLGQsJn0h1EYu-5/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1nxMRrJlSayjDIfr5CmHO1NzAw3COhsLi/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Qs3WEyMGrmagiHIkkFEueWNnJhkUeR1s/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1D-G2_Q0SS3M8zyJbg_XzkF2ANPw1HTuX/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1mdmJsDGO-YtJAOF_yPKl6lq4PJOIbQhT/view?usp=drive_link
|
||||
https://drive.google.com/file/d/11m9bwfop_sPmnQr_8amB6EEsrbAeG_z5/view?usp=drive_link
|
||||
https://drive.google.com/file/d/19tyYt5FMn5kru0g9o2nMJhKPnsDqkIZv/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1XvTpUdsVTZ-vydvdYYmynbma--HfUGSl/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1MO3hFu68J6NohTzr9aB_fY02VA6QSOqj/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Lh-UjwAk__04YOTWINF_QGVU8SjetVaY/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1jkSOUwZV5GJ7rZlVeErjcu0DBQs8Np0d/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1VIN1eLI-93WrVQwCjsv6XQr353DqqBYA/view?usp=drive_link
|
||||
@@ -1,8 +0,0 @@
|
||||
https://drive.google.com/drive/folders/1EgKar7rWBmTIRmeJYZciSwjZx3uP2mHO
|
||||
https://drive.google.com/file/d/12eYWQO15atK2hBjXhynPJd9MKAj_42pz/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1Ul4oEeICJDjgfYTl4H1uaisTzVYIM6wd/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1WSF-OG8lKSe2wVYCv5D1aJNipxpgddk-/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1_ppD5j5sFh26aWW0JmhLzJMeNB-lCArk/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1WUp846dgWXYhu4oJfhHxiU6YL_7N6s4W/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1HRZNAIoAQw_uYiPwnBvtBioQoqiqoXdA/view?usp=drive_link
|
||||
https://drive.google.com/file/d/1hedGq-QDMnIn8GlXXBC3GiEJ_Y-LTxyt/view?usp=drive_link
|
||||
@@ -1,634 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
|
||||
|
||||
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import numbers
|
||||
import os
|
||||
from functools import cached_property
|
||||
|
||||
import numcodecs
|
||||
import numpy as np
|
||||
import zarr
|
||||
|
||||
|
||||
def check_chunks_compatible(chunks: tuple, shape: tuple):
    """Assert that ``chunks`` is a usable chunk specification for ``shape``.

    The two sequences must have the same length, and every chunk entry must
    be a strictly positive integer (any ``numbers.Integral``).

    Raises:
        AssertionError: if either condition is violated.
    """
    assert len(chunks) == len(shape)
    # The isinstance test short-circuits before the comparison, so a
    # non-integer entry fails the assertion rather than the `>` operator.
    assert all(isinstance(extent, numbers.Integral) and extent > 0 for extent in chunks)
|
||||
|
||||
|
||||
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
    """Re-chunk and/or re-compress the array stored at ``group[name]``.

    Args:
        group: container holding the array; must support item access,
            ``move`` and ``del`` (presumably a ``zarr.Group`` — confirm with callers).
        name: key of the array inside ``group``.
        chunks: full chunk tuple to use. When ``None`` it is derived from
            ``chunk_length`` or, failing that, from the array's current chunks.
        chunk_length: new chunk size along the first axis only; ignored when
            ``chunks`` is given.
        compressor: compressor to use; ``None`` keeps the array's current one.
        tmp_key: temporary key the old array is moved to while copying.

    Returns:
        The array now stored at ``group[name]``. If neither chunks nor
        compressor differ from the current ones, the existing array is
        returned untouched.
    """
    source = group[name]

    if chunks is None:
        if chunk_length is None:
            chunks = source.chunks
        else:
            # only the leading axis is re-chunked; trailing axes are kept
            chunks = (chunk_length,) + source.chunks[1:]
    check_chunks_compatible(chunks, source.shape)

    target_compressor = source.compressor if compressor is None else compressor

    nothing_to_do = chunks == source.chunks and target_compressor == source.compressor
    if nothing_to_do:
        return source

    # Park the old array under a temporary key, copy it back under its
    # original name with the new layout, then drop the parked copy.
    group.move(name, tmp_key)
    parked = group[tmp_key]
    _n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
        source=parked,
        dest=group,
        name=name,
        chunks=chunks,
        compressor=target_compressor,
    )
    del group[tmp_key]
    return group[name]
|
||||
|
||||
|
||||
def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=None):
    """Pick a chunk shape whose byte size is close to ``target_chunk_bytes``.

    Trailing axes are kept whole for as long as they fit in the byte budget;
    the first axis that would overflow it is cut to a length that roughly
    fills the budget, and every axis before it gets a chunk length of 1.

    Common shapes
    T,D
    T,N,D
    T,H,W,C
    T,N,H,W,C

    Args:
        shape: array shape.
        dtype: anything accepted by ``np.dtype``; only its item size is used.
        target_chunk_bytes: desired chunk size in bytes.
        max_chunk_length: optional cap on the chunk length of the first axis.

    Returns:
        A tuple with one chunk length per axis of ``shape``.
    """
    bytes_per_item = np.dtype(dtype).itemsize
    ndim = len(shape)

    # Work on the reversed shape so that index 0 is the innermost axis.
    reversed_shape = list(shape[::-1])
    if max_chunk_length is not None:
        reversed_shape[-1] = int(max_chunk_length)

    # Find the axis at which the running byte size crosses the budget.
    # Defaults to the outermost axis when everything inside it fits.
    split_at = ndim - 1
    for axis in range(ndim - 1):
        bytes_without_axis = bytes_per_item * np.prod(reversed_shape[:axis])
        bytes_with_axis = bytes_per_item * np.prod(reversed_shape[: axis + 1])
        if bytes_without_axis <= target_chunk_bytes < bytes_with_axis:
            split_at = axis

    reversed_chunks = reversed_shape[:split_at]
    inner_bytes = bytes_per_item * np.prod(reversed_shape[:split_at])
    axis_extent = reversed_shape[split_at]
    budget_length = math.ceil(target_chunk_bytes / inner_bytes)
    reversed_chunks.append(min(axis_extent, budget_length))

    # Axes outside the split axis are chunked one element at a time.
    reversed_chunks.extend([1] * (ndim - len(reversed_chunks)))
    return tuple(reversed_chunks[::-1])
|
||||
|
||||
|
||||
class ReplayBuffer:
    """
    Zarr-based temporal datastructure.
    Assumes first dimension to be time. Only chunk in time dimension.

    The buffer wraps a ``root`` container with two entries: ``"data"``
    (arrays sharing the same first-axis length) and ``"meta"`` (which must
    contain ``"episode_ends"``, the cumulative end index of every episode).
    ``root`` is either a ``zarr.Group`` ("zarr" backend) or a plain nested
    ``dict`` of numpy arrays ("numpy" backend).
    """

    def __init__(self, root: zarr.Group | dict[str, dict]):
        """
        Dummy constructor. Use copy_from* and create_from* class methods instead.

        Only validates ``root``: it must contain "data" and "meta", "meta"
        must contain "episode_ends", and every data array must be as long as
        the last episode end.
        """
        assert "data" in root
        assert "meta" in root
        assert "episode_ends" in root["meta"]
        for value in root["data"].values():
            assert value.shape[0] == root["meta"]["episode_ends"][-1]
        self.root = root

    # ============= create constructors ===============
    @classmethod
    def create_empty_zarr(cls, storage=None, root=None):
        """Create an empty zarr-backed buffer.

        Uses ``root`` when given; otherwise creates a group on ``storage``
        (an in-memory store when ``storage`` is also ``None``). Existing
        "data"/"meta" groups and an existing "episode_ends" are kept.
        """
        if root is None:
            if storage is None:
                storage = zarr.MemoryStore()
            root = zarr.group(store=storage)
        root.require_group("data", overwrite=False)
        meta = root.require_group("meta", overwrite=False)
        if "episode_ends" not in meta:
            meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
        return cls(root=root)

    @classmethod
    def create_empty_numpy(cls):
        """Create an empty numpy-backed buffer (plain dicts, no episodes)."""
        root = {"data": {}, "meta": {"episode_ends": np.zeros((0,), dtype=np.int64)}}
        return cls(root=root)

    @classmethod
    def create_from_group(cls, group, **kwargs):
        """Wrap ``group`` in a buffer, initialising it first if it has no "data" entry."""
        if "data" not in group:
            # create from scratch
            buffer = cls.create_empty_zarr(root=group, **kwargs)
        else:
            # already exist
            buffer = cls(root=group, **kwargs)
        return buffer

    @classmethod
    def create_from_path(cls, zarr_path, mode="r", **kwargs):
        """
        Open an on-disk zarr directly (for dataset larger than memory).
        Slower.
        """
        group = zarr.open(os.path.expanduser(zarr_path), mode)
        return cls.create_from_group(group, **kwargs)

    # ============= copy constructors ===============
    @classmethod
    def copy_from_store(
        cls,
        src_store,
        store=None,
        keys=None,
        chunks: dict[str, tuple] | None = None,
        compressors: dict | str | numcodecs.abc.Codec | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """
        Load to memory.

        With ``store=None`` the source is read fully into numpy arrays
        (numpy backend). Otherwise "meta" and the selected "data" arrays are
        copied into ``store``; an array is copied chunk-for-chunk when its
        resolved chunks and compressor equal the source's, and re-encoded
        otherwise. ``keys`` restricts which data arrays are copied (all when
        ``None``). Extra ``kwargs`` are accepted and ignored.
        """
        src_root = zarr.group(src_store)
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        root = None
        if store is None:
            # numpy backend
            meta = {}
            for key, value in src_root["meta"].items():
                if len(value.shape) == 0:
                    meta[key] = np.array(value)
                else:
                    meta[key] = value[:]

            if keys is None:
                keys = src_root["data"].keys()
            data = {}
            for key in keys:
                arr = src_root["data"][key]
                data[key] = arr[:]

            root = {"meta": meta, "data": data}
        else:
            root = zarr.group(store=store)
            # copy without recompression
            zarr.copy_store(
                source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
            )
            data_group = root.create_group("data", overwrite=True)
            if keys is None:
                keys = src_root["data"].keys()
            for key in keys:
                value = src_root["data"][key]
                cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
                cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
                if cks == value.chunks and cpr == value.compressor:
                    # copy without recompression
                    this_path = "/data/" + key
                    zarr.copy_store(
                        source=src_store,
                        dest=store,
                        source_path=this_path,
                        dest_path=this_path,
                        if_exists=if_exists,
                    )
                else:
                    # copy with recompression
                    zarr.copy(
                        source=value,
                        dest=data_group,
                        name=key,
                        chunks=cks,
                        compressor=cpr,
                        if_exists=if_exists,
                    )
        buffer = cls(root=root)
        return buffer

    @classmethod
    def copy_from_path(
        cls,
        zarr_path,
        backend=None,
        store=None,
        keys=None,
        chunks: dict[str, tuple] | None = None,
        compressors: dict | str | numcodecs.abc.Codec | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """
        Copy an on-disk zarr to in-memory compressed.
        Recommended

        ``backend="numpy"`` is deprecated: it prints a notice and forces
        ``store=None``. All other arguments are forwarded to
        ``copy_from_store``.
        """
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        if backend == "numpy":
            print("backend argument is deprecated!")
            store = None
        group = zarr.open(os.path.expanduser(zarr_path), "r")
        return cls.copy_from_store(
            src_store=group.store,
            store=store,
            keys=keys,
            chunks=chunks,
            compressors=compressors,
            if_exists=if_exists,
            **kwargs,
        )

    # ============= save methods ===============
    def save_to_store(
        self,
        store,
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """Write the buffer's "meta" and "data" into ``store`` and return ``store``.

        Zarr arrays whose resolved chunks and compressor are unchanged are
        copied without re-encoding; everything else is written with the
        resolved chunks/compressor. Extra ``kwargs`` are accepted and ignored.
        """
        root = zarr.group(store)
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        if self.backend == "zarr":
            # recompression free copy
            zarr.copy_store(
                source=self.root.store,
                dest=store,
                source_path="/meta",
                dest_path="/meta",
                if_exists=if_exists,
            )
        else:
            meta_group = root.create_group("meta", overwrite=True)
            # save meta, no chunking
            for key, value in self.root["meta"].items():
                meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)

        # save data, chunk
        data_group = root.create_group("data", overwrite=True)
        for key, value in self.root["data"].items():
            cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
            cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
            if isinstance(value, zarr.Array):
                if cks == value.chunks and cpr == value.compressor:
                    # copy without recompression
                    this_path = "/data/" + key
                    zarr.copy_store(
                        source=self.root.store,
                        dest=store,
                        source_path=this_path,
                        dest_path=this_path,
                        if_exists=if_exists,
                    )
                else:
                    # copy with recompression
                    zarr.copy(
                        source=value,
                        dest=data_group,
                        name=key,
                        chunks=cks,
                        compressor=cpr,
                        if_exists=if_exists,
                    )
            else:
                # numpy
                data_group.array(name=key, data=value, chunks=cks, compressor=cpr)
        return store

    def save_to_path(
        self,
        zarr_path,
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """Save the buffer to a directory store at ``zarr_path`` (``~`` is expanded)."""
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
        return self.save_to_store(
            store, chunks=chunks, compressors=compressors, if_exists=if_exists, **kwargs
        )

    @staticmethod
    def resolve_compressor(compressor="default"):
        """Map the presets "default" and "disk" to Blosc codecs; return anything else unchanged."""
        if compressor == "default":
            compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
        elif compressor == "disk":
            compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
        return compressor

    @classmethod
    def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
        """Choose the compressor for ``array``.

        Order: per-key entry of a ``compressors`` dict, then the array's own
        compressor if it is a zarr array, then the "default" preset. A
        non-dict ``compressors`` applies to every key.
        """
        # allows compressor to be explicitly set to None
        cpr = "nil"
        if isinstance(compressors, dict):
            if key in compressors:
                cpr = cls.resolve_compressor(compressors[key])
            elif isinstance(array, zarr.Array):
                cpr = array.compressor
        else:
            cpr = cls.resolve_compressor(compressors)
        # backup default
        if cpr == "nil":
            cpr = cls.resolve_compressor("default")
        return cpr

    @classmethod
    def _resolve_array_chunks(cls, chunks: dict | tuple, key, array):
        """Choose and validate the chunk shape for ``array``.

        Order: per-key entry of a ``chunks`` dict, then the array's own
        chunks if it is a zarr array, then ``get_optimal_chunks``. A tuple
        ``chunks`` applies to every key.

        Raises:
            TypeError: if ``chunks`` is neither a dict nor a tuple.
        """
        cks = None
        if isinstance(chunks, dict):
            if key in chunks:
                cks = chunks[key]
            elif isinstance(array, zarr.Array):
                cks = array.chunks
        elif isinstance(chunks, tuple):
            cks = chunks
        else:
            raise TypeError(f"Unsupported chunks type {type(chunks)}")
        # backup default
        if cks is None:
            cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
        # check
        check_chunks_compatible(chunks=cks, shape=array.shape)
        return cks

    # ============= properties =================
    @cached_property
    def data(self):
        """The "data" container of ``root`` (cached on first access)."""
        return self.root["data"]

    @cached_property
    def meta(self):
        """The "meta" container of ``root`` (cached on first access)."""
        return self.root["meta"]

    def update_meta(self, data):
        """Store every entry of ``data`` in "meta" as an array and return the meta container.

        Raises:
            TypeError: if a value converts to an object-dtype numpy array.
        """
        # sanitize data
        np_data = {}
        for key, value in data.items():
            if isinstance(value, np.ndarray):
                np_data[key] = value
            else:
                arr = np.array(value)
                if arr.dtype == object:
                    raise TypeError(f"Invalid value type {type(value)}")
                np_data[key] = arr

        meta_group = self.meta
        if self.backend == "zarr":
            for key, value in np_data.items():
                meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True)
        else:
            meta_group.update(np_data)

        return meta_group

    @property
    def episode_ends(self):
        """Cumulative end index (exclusive) of every episode."""
        return self.meta["episode_ends"]

    def get_episode_idxs(self):
        """Return an int64 array giving, for every step, the index of its episode.

        Computed by repeating each episode index by that episode's length.
        An empty buffer yields an empty array.
        """
        lengths = self.episode_lengths
        return np.repeat(np.arange(len(lengths), dtype=np.int64), lengths)

    @property
    def backend(self):
        """"zarr" when ``root`` is a ``zarr.Group``, otherwise "numpy"."""
        backend = "numpy"
        if isinstance(self.root, zarr.Group):
            backend = "zarr"
        return backend

    # =========== dict-like API ==============
    def __repr__(self) -> str:
        if self.backend == "zarr":
            return str(self.root.tree())
        else:
            return super().__repr__()

    def keys(self):
        """Keys of the data container."""
        return self.data.keys()

    def values(self):
        """Arrays of the data container."""
        return self.data.values()

    def items(self):
        """(key, array) pairs of the data container."""
        return self.data.items()

    def __getitem__(self, key):
        return self.data[key]

    def __contains__(self, key):
        return key in self.data

    # =========== our API ==============
    @property
    def n_steps(self):
        """Total number of steps: the last episode end, or 0 when there are no episodes."""
        if len(self.episode_ends) == 0:
            return 0
        return self.episode_ends[-1]

    @property
    def n_episodes(self):
        """Number of stored episodes."""
        return len(self.episode_ends)

    @property
    def chunk_size(self):
        """First-axis chunk length of the first data array (zarr backend), else ``None``."""
        if self.backend == "zarr":
            return next(iter(self.data.arrays()))[-1].chunks[0]
        return None

    @property
    def episode_lengths(self):
        """Length of every episode, obtained by differencing ``episode_ends``."""
        ends = self.episode_ends[:]
        ends = np.insert(ends, 0, 0)
        lengths = np.diff(ends)
        return lengths

    def add_episode(
        self,
        data: dict[str, np.ndarray],
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
    ):
        """Append one episode to the buffer.

        Args:
            data: non-empty mapping of key to array; all arrays must have at
                least one axis and the same first-axis length.
            chunks: chunk spec used when a new zarr array has to be created.
            compressors: compressor spec used when a new zarr array has to be
                created.
        """
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        assert len(data) > 0
        is_zarr = self.backend == "zarr"

        curr_len = self.n_steps
        episode_length = None
        for value in data.values():
            assert len(value.shape) >= 1
            if episode_length is None:
                episode_length = len(value)
            else:
                assert episode_length == len(value)
        new_len = curr_len + episode_length

        for key, value in data.items():
            new_shape = (new_len,) + value.shape[1:]
            # create array
            if key not in self.data:
                if is_zarr:
                    cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
                    cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
                    arr = self.data.zeros(
                        name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
                    )
                else:
                    # copy data to prevent modify
                    arr = np.zeros(shape=new_shape, dtype=value.dtype)
                    self.data[key] = arr
            else:
                arr = self.data[key]
                assert value.shape[1:] == arr.shape[1:]
                # same method for both zarr and numpy
                if is_zarr:
                    arr.resize(new_shape)
                else:
                    arr.resize(new_shape, refcheck=False)
            # copy data
            arr[-value.shape[0] :] = value

        # append to episode ends
        episode_ends = self.episode_ends
        if is_zarr:
            episode_ends.resize(episode_ends.shape[0] + 1)
        else:
            episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
        episode_ends[-1] = new_len

        # rechunk
        if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
            rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))

    def drop_episode(self):
        """Remove the last episode by shrinking every data array and ``episode_ends`` in place."""
        is_zarr = self.backend == "zarr"
        episode_ends = self.episode_ends[:].copy()
        assert len(episode_ends) > 0
        start_idx = 0
        if len(episode_ends) > 1:
            start_idx = episode_ends[-2]
        for value in self.data.values():
            new_shape = (start_idx,) + value.shape[1:]
            if is_zarr:
                value.resize(new_shape)
            else:
                value.resize(new_shape, refcheck=False)
        if is_zarr:
            self.episode_ends.resize(len(episode_ends) - 1)
        else:
            self.episode_ends.resize(len(episode_ends) - 1, refcheck=False)

    def pop_episode(self):
        """Remove the last episode and return its data (copied before removal)."""
        assert self.n_episodes > 0
        episode = self.get_episode(self.n_episodes - 1, copy=True)
        self.drop_episode()
        return episode

    def extend(self, data):
        """Alias of ``add_episode`` with default chunks and compressors."""
        self.add_episode(data)

    def get_episode(self, idx, copy=False):
        """Return a dict with the data of episode ``idx`` (negative indices allowed)."""
        # normalise negative indices and reject out-of-range ones
        idx = list(range(len(self.episode_ends)))[idx]
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx - 1]
        end_idx = self.episode_ends[idx]
        result = self.get_steps_slice(start_idx, end_idx, copy=copy)
        return result

    def get_episode_slice(self, idx):
        """Return the ``slice`` of step indices covered by episode ``idx``.

        Negative indices count from the last episode, as in ``get_episode``.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        # Without normalisation a negative idx would skip the `idx > 0`
        # branch and produce a slice starting at step 0.
        idx = range(len(self.episode_ends))[idx]
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx - 1]
        end_idx = self.episode_ends[idx]
        return slice(start_idx, end_idx)

    def get_steps_slice(self, start, stop, step=None, copy=False):
        """Return ``{key: array[start:stop:step]}`` for every data array.

        With ``copy=True`` numpy arrays are copied so the result does not
        alias the buffer's storage.
        """
        _slice = slice(start, stop, step)

        result = {}
        for key, value in self.data.items():
            x = value[_slice]
            if copy and isinstance(value, np.ndarray):
                x = x.copy()
            result[key] = x
        return result

    # =========== chunking =============
    def get_chunks(self) -> dict:
        """Chunk shape of every data array (zarr backend only)."""
        assert self.backend == "zarr"
        chunks = {}
        for key, value in self.data.items():
            chunks[key] = value.chunks
        return chunks

    def set_chunks(self, chunks: dict):
        """Re-chunk the data arrays named in ``chunks`` whose chunk shape differs (zarr backend only)."""
        assert self.backend == "zarr"
        for key, value in chunks.items():
            if key in self.data:
                arr = self.data[key]
                if value != arr.chunks:
                    check_chunks_compatible(chunks=value, shape=arr.shape)
                    rechunk_recompress_array(self.data, key, chunks=value)

    def get_compressors(self) -> dict:
        """Compressor of every data array (zarr backend only)."""
        assert self.backend == "zarr"
        compressors = {}
        for key, value in self.data.items():
            compressors[key] = value.compressor
        return compressors

    def set_compressors(self, compressors: dict):
        """Re-compress the data arrays named in ``compressors`` whose compressor differs (zarr backend only)."""
        assert self.backend == "zarr"
        for key, value in compressors.items():
            if key in self.data:
                arr = self.data[key]
                compressor = self.resolve_compressor(value)
                if compressor != arr.compressor:
                    rechunk_recompress_array(self.data, key, compressor=compressor)
|
||||
@@ -1,202 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file contains download scripts for raw datasets.
|
||||
|
||||
Example of usage:
|
||||
```
|
||||
python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
|
||||
--raw-dir data/lerobot-raw/pusht_raw \
|
||||
--repo-id lerobot-raw/pusht_raw
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
|
||||
# {raw_repo_id: raw_format}
|
||||
AVAILABLE_RAW_REPO_IDS = {
    # Aloha datasets (HDF5)
    "lerobot-raw/aloha_mobile_cabinet_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_chair_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_elevator_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_shrimp_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_wash_pan_raw": "aloha_hdf5",
    "lerobot-raw/aloha_mobile_wipe_wine_raw": "aloha_hdf5",
    "lerobot-raw/aloha_sim_insertion_human_raw": "aloha_hdf5",
    "lerobot-raw/aloha_sim_insertion_scripted_raw": "aloha_hdf5",
    "lerobot-raw/aloha_sim_transfer_cube_human_raw": "aloha_hdf5",
    "lerobot-raw/aloha_sim_transfer_cube_scripted_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_battery_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_candy_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_coffee_new_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_coffee_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_cups_open_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_fork_pick_up_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_pingpong_test_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_pro_pencil_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_screw_driver_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_tape_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_thread_velcro_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_towel_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_vinh_cup_left_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_vinh_cup_raw": "aloha_hdf5",
    "lerobot-raw/aloha_static_ziploc_slide_raw": "aloha_hdf5",
    # Zarr datasets
    "lerobot-raw/umi_cup_in_the_wild_raw": "umi_zarr",
    "lerobot-raw/pusht_raw": "pusht_zarr",
    # Unitree H1 datasets (stored in the aloha HDF5 format)
    "lerobot-raw/unitreeh1_fold_clothes_raw": "aloha_hdf5",
    "lerobot-raw/unitreeh1_rearrange_objects_raw": "aloha_hdf5",
    "lerobot-raw/unitreeh1_two_robot_greeting_raw": "aloha_hdf5",
    "lerobot-raw/unitreeh1_warehouse_raw": "aloha_hdf5",
    # xArm datasets (pickle)
    "lerobot-raw/xarm_lift_medium_raw": "xarm_pkl",
    "lerobot-raw/xarm_lift_medium_replay_raw": "xarm_pkl",
    "lerobot-raw/xarm_push_medium_raw": "xarm_pkl",
    "lerobot-raw/xarm_push_medium_replay_raw": "xarm_pkl",
    # Open X-Embodiment datasets (RLDS)
    "lerobot-raw/fractal20220817_data_raw": "openx_rlds.fractal20220817_data",
    "lerobot-raw/kuka_raw": "openx_rlds.kuka",
    "lerobot-raw/bridge_openx_raw": "openx_rlds.bridge_openx",
    "lerobot-raw/taco_play_raw": "openx_rlds.taco_play",
    "lerobot-raw/jaco_play_raw": "openx_rlds.jaco_play",
    "lerobot-raw/berkeley_cable_routing_raw": "openx_rlds.berkeley_cable_routing",
    "lerobot-raw/roboturk_raw": "openx_rlds.roboturk",
    "lerobot-raw/nyu_door_opening_surprising_effectiveness_raw": "openx_rlds.nyu_door_opening_surprising_effectiveness",
    "lerobot-raw/viola_raw": "openx_rlds.viola",
    "lerobot-raw/berkeley_autolab_ur5_raw": "openx_rlds.berkeley_autolab_ur5",
    "lerobot-raw/toto_raw": "openx_rlds.toto",
    "lerobot-raw/language_table_raw": "openx_rlds.language_table",
    "lerobot-raw/columbia_cairlab_pusht_real_raw": "openx_rlds.columbia_cairlab_pusht_real",
    "lerobot-raw/stanford_kuka_multimodal_dataset_raw": "openx_rlds.stanford_kuka_multimodal_dataset",
    "lerobot-raw/nyu_rot_dataset_raw": "openx_rlds.nyu_rot_dataset",
    "lerobot-raw/io_ai_tech_raw": "openx_rlds.io_ai_tech",
    "lerobot-raw/stanford_hydra_dataset_raw": "openx_rlds.stanford_hydra_dataset",
    "lerobot-raw/austin_buds_dataset_raw": "openx_rlds.austin_buds_dataset",
    "lerobot-raw/nyu_franka_play_dataset_raw": "openx_rlds.nyu_franka_play_dataset",
    "lerobot-raw/maniskill_dataset_raw": "openx_rlds.maniskill_dataset",
    "lerobot-raw/furniture_bench_dataset_raw": "openx_rlds.furniture_bench_dataset",
    "lerobot-raw/cmu_franka_exploration_dataset_raw": "openx_rlds.cmu_franka_exploration_dataset",
    "lerobot-raw/ucsd_kitchen_dataset_raw": "openx_rlds.ucsd_kitchen_dataset",
    "lerobot-raw/ucsd_pick_and_place_dataset_raw": "openx_rlds.ucsd_pick_and_place_dataset",
    "lerobot-raw/spoc_raw": "openx_rlds.spoc",
    "lerobot-raw/austin_sailor_dataset_raw": "openx_rlds.austin_sailor_dataset",
    "lerobot-raw/austin_sirius_dataset_raw": "openx_rlds.austin_sirius_dataset",
    "lerobot-raw/bc_z_raw": "openx_rlds.bc_z",
    "lerobot-raw/utokyo_pr2_opening_fridge_raw": "openx_rlds.utokyo_pr2_opening_fridge",
    "lerobot-raw/utokyo_pr2_tabletop_manipulation_raw": "openx_rlds.utokyo_pr2_tabletop_manipulation",
    "lerobot-raw/utokyo_xarm_pick_and_place_raw": "openx_rlds.utokyo_xarm_pick_and_place",
    "lerobot-raw/utokyo_xarm_bimanual_raw": "openx_rlds.utokyo_xarm_bimanual",
    "lerobot-raw/utokyo_saytap_raw": "openx_rlds.utokyo_saytap",
    "lerobot-raw/robo_net_raw": "openx_rlds.robo_net",
    "lerobot-raw/robo_set_raw": "openx_rlds.robo_set",
    "lerobot-raw/berkeley_mvp_raw": "openx_rlds.berkeley_mvp",
    "lerobot-raw/berkeley_rpt_raw": "openx_rlds.berkeley_rpt",
    "lerobot-raw/kaist_nonprehensile_raw": "openx_rlds.kaist_nonprehensile",
    "lerobot-raw/stanford_mask_vit_raw": "openx_rlds.stanford_mask_vit",
    "lerobot-raw/tokyo_u_lsmo_raw": "openx_rlds.tokyo_u_lsmo",
    "lerobot-raw/dlr_sara_pour_raw": "openx_rlds.dlr_sara_pour",
    "lerobot-raw/dlr_sara_grid_clamp_raw": "openx_rlds.dlr_sara_grid_clamp",
    "lerobot-raw/dlr_edan_shared_control_raw": "openx_rlds.dlr_edan_shared_control",
    "lerobot-raw/asu_table_top_raw": "openx_rlds.asu_table_top",
    "lerobot-raw/stanford_robocook_raw": "openx_rlds.stanford_robocook",
    "lerobot-raw/imperialcollege_sawyer_wrist_cam_raw": "openx_rlds.imperialcollege_sawyer_wrist_cam",
    "lerobot-raw/iamlab_cmu_pickup_insert_raw": "openx_rlds.iamlab_cmu_pickup_insert",
    "lerobot-raw/uiuc_d3field_raw": "openx_rlds.uiuc_d3field",
    "lerobot-raw/utaustin_mutex_raw": "openx_rlds.utaustin_mutex",
    "lerobot-raw/berkeley_fanuc_manipulation_raw": "openx_rlds.berkeley_fanuc_manipulation",
    "lerobot-raw/cmu_playing_with_food_raw": "openx_rlds.cmu_playing_with_food",
    "lerobot-raw/cmu_play_fusion_raw": "openx_rlds.cmu_play_fusion",
    "lerobot-raw/cmu_stretch_raw": "openx_rlds.cmu_stretch",
    "lerobot-raw/berkeley_gnm_recon_raw": "openx_rlds.berkeley_gnm_recon",
    "lerobot-raw/berkeley_gnm_cory_hall_raw": "openx_rlds.berkeley_gnm_cory_hall",
    "lerobot-raw/berkeley_gnm_sac_son_raw": "openx_rlds.berkeley_gnm_sac_son",
    "lerobot-raw/droid_raw": "openx_rlds.droid",
    "lerobot-raw/droid_100_raw": "openx_rlds.droid100",
    "lerobot-raw/fmb_raw": "openx_rlds.fmb",
    "lerobot-raw/dobbe_raw": "openx_rlds.dobbe",
    "lerobot-raw/usc_cloth_sim_raw": "openx_rlds.usc_cloth_sim",
    "lerobot-raw/plex_robosuite_raw": "openx_rlds.plex_robosuite",
    "lerobot-raw/conq_hose_manipulation_raw": "openx_rlds.conq_hose_manipulation",
    "lerobot-raw/vima_raw": "openx_rlds.vima",
    "lerobot-raw/robot_vqa_raw": "openx_rlds.robot_vqa",
    "lerobot-raw/mimic_play_raw": "openx_rlds.mimic_play",
    "lerobot-raw/tidybot_raw": "openx_rlds.tidybot",
    "lerobot-raw/eth_agent_affordances_raw": "openx_rlds.eth_agent_affordances",
}
|
||||
|
||||
|
||||
def download_raw(raw_dir: Path, repo_id: str):
    """Download a raw dataset repository from the Hugging Face hub into ``raw_dir``.

    Args:
        raw_dir: Local directory that receives the snapshot. It is created
            (with parents) if it does not exist.
        repo_id: Hub identifier of the form ``<user_or_org>/<dataset_name>``.

    Warns:
        UserWarning: if the dataset name does not end with ``_raw``, or if the
            last two components of ``raw_dir`` are not ``<user>/<dataset>``.
            Both are naming-convention hints only; the download still runs.
    """
    check_repo_id(repo_id)
    user_id, dataset_id = repo_id.split("/")

    if not dataset_id.endswith("_raw"):
        warnings.warn(
            f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this\n"
            "naming convention by renaming your repository is advised, but not mandatory.",
            stacklevel=1,
        )

    # Send warning if raw_dir isn't well formatted.
    # Slicing (instead of indexing parts[-2]) keeps this a warning rather than
    # an IndexError when raw_dir has fewer than two components.
    if tuple(raw_dir.parts[-2:]) != (user_id, dataset_id):
        warnings.warn(
            f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that\n"
            "match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised,\n"
            "but not mandatory.",
            stacklevel=1,
        )
    raw_dir.mkdir(parents=True, exist_ok=True)

    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
    snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
|
||||
|
||||
def download_all_raw_datasets(data_dir: Path | None = None):
    """Download every repository listed in ``AVAILABLE_RAW_REPO_IDS``.

    Each dataset is stored under ``data_dir / repo_id``; ``data_dir`` defaults
    to ``Path("data")`` when not given.
    """
    root = Path("data") if data_dir is None else data_dir
    for repo_id in AVAILABLE_RAW_REPO_IDS:
        download_raw(root / repo_id, repo_id)
|
||||
|
||||
|
||||
def main():
    """Parse ``--raw-dir`` and ``--repo-id`` from the command line and download that dataset."""
    known_repo_ids = list(AVAILABLE_RAW_REPO_IDS.keys())
    cli = argparse.ArgumentParser(
        description=(
            "A script to download raw datasets from Hugging Face hub to a local directory. Here is a\n"
            f"non exhaustive list of available repositories to use in `--repo-id`: {known_repo_ids}"
        ),
    )
    cli.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
    )
    cli.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help=(
            "Repositery identifier on Hugging Face: a community or a user name `/` the name of\n"
            "the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`)."
        ),
    )
    options = cli.parse_args()
    download_raw(raw_dir=options.raw_dir, repo_id=options.repo_id)


if __name__ == "__main__":
    main()
|
||||
@@ -1,184 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Use this script to batch encode lerobot dataset from their raw format to LeRobotDataset and push their updated
|
||||
version to the hub. Under the hood, this script reuses 'push_dataset_to_hub.py'. It assumes that you already
|
||||
downloaded raw datasets, which you can do with the related '_download_raw.py' script.
|
||||
|
||||
For instance, for codebase_version = 'v1.6', the following command was run, assuming raw datasets from
|
||||
lerobot-raw were downloaded in 'raw/datasets/directory':
|
||||
```bash
|
||||
python lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py \
|
||||
--raw-dir raw/datasets/directory \
|
||||
--raw-repo-ids lerobot-raw \
|
||||
--local-dir push/datasets/directory \
|
||||
--tests-data-dir tests/data \
|
||||
--push-repo lerobot \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 2 \
|
||||
--crf 30
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
|
||||
|
||||
|
||||
def get_push_repo_id_from_raw(raw_repo_id: str, push_repo: str) -> str:
    """Build the destination repo id for a raw repo id.

    Takes the second ``/``-separated component of ``raw_repo_id``, drops a
    trailing ``_raw`` if present, and prefixes it with ``push_repo``.
    """
    components = raw_repo_id.split("/")
    dataset_name = components[1].removesuffix("_raw")
    return f"{push_repo}/{dataset_name}"
|
||||
|
||||
|
||||
def encode_datasets(
|
||||
raw_dir: Path,
|
||||
raw_repo_ids: list[str],
|
||||
push_repo: str,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
local_dir: Path | None = None,
|
||||
tests_data_dir: Path | None = None,
|
||||
raw_format: str | None = None,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
if len(raw_repo_ids) == 1 and raw_repo_ids[0].lower() == "lerobot-raw":
|
||||
raw_repo_ids_format = AVAILABLE_RAW_REPO_IDS
|
||||
else:
|
||||
if raw_format is None:
|
||||
raise ValueError(raw_format)
|
||||
raw_repo_ids_format = {id_: raw_format for id_ in raw_repo_ids}
|
||||
|
||||
for raw_repo_id, repo_raw_format in raw_repo_ids_format.items():
|
||||
check_repo_id(raw_repo_id)
|
||||
dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
|
||||
dataset_raw_dir = raw_dir / raw_repo_id
|
||||
dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
|
||||
encoding = {
|
||||
"vcodec": vcodec,
|
||||
"pix_fmt": pix_fmt,
|
||||
"g": g,
|
||||
"crf": crf,
|
||||
}
|
||||
|
||||
if not (dataset_raw_dir).is_dir():
|
||||
raise NotADirectoryError(dataset_raw_dir)
|
||||
|
||||
if not dry_run:
|
||||
push_dataset_to_hub(
|
||||
dataset_raw_dir,
|
||||
raw_format=repo_raw_format,
|
||||
repo_id=dataset_repo_id_push,
|
||||
local_dir=dataset_dir,
|
||||
resume=True,
|
||||
encoding=encoding,
|
||||
tests_data_dir=tests_data_dir,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"DRY RUN: {dataset_raw_dir} --> {dataset_dir} --> {dataset_repo_id_push}@{CODEBASE_VERSION}"
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Parse the command-line options and forward them to ``encode_datasets``."""
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        (
            "--raw-dir",
            {
                "type": Path,
                "default": Path("data"),
                "help": "Directory where raw datasets are located.",
            },
        ),
        (
            "--raw-repo-ids",
            {
                "type": str,
                "nargs": "*",
                "default": ["lerobot-raw"],
                "help": (
                    "Raw dataset repo ids. if 'lerobot-raw', the keys from `AVAILABLE_RAW_REPO_IDS` will be\n"
                    "used and raw datasets will be fetched from the 'lerobot-raw/' repo and pushed with their\n"
                    "associated format. It is assumed that each dataset is located at `raw_dir / raw_repo_id` "
                ),
            },
        ),
        (
            "--raw-format",
            {
                "type": str,
                "default": None,
                "help": (
                    "Raw format to use for the raw repo-ids. Must be specified if --raw-repo-ids is not\n"
                    "'lerobot-raw'"
                ),
            },
        ),
        (
            "--local-dir",
            {
                "type": Path,
                "default": None,
                "help": (
                    "When provided, writes the dataset converted to LeRobotDataset format in this directory\n"
                    "(e.g. `data/lerobot/aloha_mobile_chair`)."
                ),
            },
        ),
        (
            "--push-repo",
            {
                "type": str,
                "default": "lerobot",
                "help": "Repo to upload datasets to",
            },
        ),
        (
            "--vcodec",
            {
                "type": str,
                "default": "libsvtav1",
                "help": "Codec to use for encoding videos",
            },
        ),
        (
            "--pix-fmt",
            {
                "type": str,
                "default": "yuv420p",
                "help": "Pixel formats (chroma subsampling) to be used for encoding",
            },
        ),
        (
            "--g",
            {
                "type": int,
                "default": 2,
                "help": "Group of pictures sizes to be used for encoding.",
            },
        ),
        (
            "--crf",
            {
                "type": int,
                "default": 30,
                "help": "Constant rate factors to be used for encoding.",
            },
        ),
        (
            "--tests-data-dir",
            {
                "type": Path,
                "default": None,
                "help": (
                    "When provided, save tests artifacts into the given directory "
                    "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
                ),
            },
        ),
        (
            "--dry-run",
            {
                "type": int,
                "default": 0,
                "help": "If not set to 0, this script won't download or upload anything.",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()
    encode_datasets(**vars(options))


if __name__ == "__main__":
    main()
|
||||
@@ -1,326 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# imagecodecs/numcodecs.py
|
||||
|
||||
# Copyright (c) 2021-2022, Christoph Gohlke
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
# POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Copied from: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/codecs/imagecodecs_numcodecs.py#L1
|
||||
"""Additional numcodecs implemented using imagecodecs."""
|
||||
|
||||
__version__ = "2022.9.26"
|
||||
|
||||
__all__ = ("register_codecs",)
|
||||
|
||||
import imagecodecs
|
||||
import numpy
|
||||
from numcodecs.abc import Codec
|
||||
from numcodecs.registry import get_codec, register_codec
|
||||
|
||||
# TODO (azouitine): Remove useless codecs
|
||||
|
||||
|
||||
def protective_squeeze(x: numpy.ndarray):
    """
    Reshape ``x`` so that at most one dimension precedes its last three.

    Image dim expected to be *, H, W, C. When the leading dimensions hold more
    than one image they are merged into a single axis; otherwise they are
    dropped and only the last three dimensions are kept.
    """
    image_dims = x.shape[-3:]
    batch_dims = x.shape[:-3]
    if batch_dims and numpy.prod(batch_dims) > 1:
        return x.reshape((-1,) + image_dims)
    return x.reshape(image_dims)
|
||||
|
||||
|
||||
def get_default_image_compressor(**kwargs):
    """Return the default image codec instance, with ``kwargs`` overriding the defaults.

    A ``JpegXl`` codec is returned when ``imagecodecs.JPEGXL`` is truthy,
    otherwise a ``Jpeg2k`` codec.
    """
    if not imagecodecs.JPEGXL:
        return Jpeg2k(**{"level": 50, **kwargs})

    # has JPEGXL
    jpegxl_options = {
        "effort": 3,
        "distance": 0.3,
        # bug in libjxl, invalid codestream for non-lossless
        # when decoding speed > 1
        "decodingspeed": 1,
        **kwargs,
    }
    return JpegXl(**jpegxl_options)
|
||||
|
||||
|
||||
class Jpeg2k(Codec):
    """JPEG 2000 codec for numcodecs."""

    codec_id = "imagecodecs_jpeg2k"

    def __init__(
        self,
        level=None,
        codecformat=None,
        colorspace=None,
        tile=None,
        reversible=None,
        bitspersample=None,
        resolutions=None,
        numthreads=None,
        verbose=0,
    ):
        """Store the options later forwarded to the imagecodecs JPEG 2000 functions."""
        self.level = level
        self.codecformat = codecformat
        self.colorspace = colorspace
        # Any iterable is accepted for ``tile``; it is stored as a tuple.
        self.tile = tuple(tile) if tile is not None else None
        self.reversible = reversible
        self.bitspersample = bitspersample
        self.resolutions = resolutions
        self.numthreads = numthreads
        self.verbose = verbose

    def encode(self, buf):
        """Encode ``buf`` with ``imagecodecs.jpeg2k_encode`` after reshaping it with ``protective_squeeze``."""
        image = protective_squeeze(numpy.asarray(buf))
        encode_options = {
            "level": self.level,
            "codecformat": self.codecformat,
            "colorspace": self.colorspace,
            "tile": self.tile,
            "reversible": self.reversible,
            "bitspersample": self.bitspersample,
            "resolutions": self.resolutions,
            "numthreads": self.numthreads,
            "verbose": self.verbose,
        }
        return imagecodecs.jpeg2k_encode(image, **encode_options)

    def decode(self, buf, out=None):
        """Decode ``buf`` with ``imagecodecs.jpeg2k_decode``, optionally into ``out``."""
        return imagecodecs.jpeg2k_decode(
            buf,
            verbose=self.verbose,
            numthreads=self.numthreads,
            out=out,
        )
|
||||
|
||||
|
||||
class JpegXl(Codec):
    """JPEG XL codec for numcodecs."""

    codec_id = "imagecodecs_jpegxl"

    def __init__(
        self,
        # encode
        level=None,
        effort=None,
        distance=None,
        lossless=None,
        decodingspeed=None,
        photometric=None,
        planar=None,
        usecontainer=None,
        # decode
        index=None,
        keeporientation=None,
        # both
        numthreads=None,
    ):
        """
        Configure a JPEG XL codec.

        Float input must be in the nominal range 0..1.

        Currently L, LA, RGB, RGBA images are supported in contig mode.
        Extra channels are only supported for grayscale images in planar mode.

        Parameters
        ----------
        level : Default to None, i.e. not overwriting the lossless and
            decodingspeed options.
            When < 0: use lossless compression.
            When in [0,1,2,3,4]: sets the decoding speed tier for the provided
            options. Minimum is 0 (slowest to decode, best quality/density),
            and maximum is 4 (fastest to decode, at the cost of some
            quality/density).
        effort : Default to 3.
            Sets encoder effort/speed level without affecting decoding speed.
            Valid values are, from faster to slower speed: 1:lightning
            2:thunder 3:falcon 4:cheetah 5:hare 6:wombat 7:squirrel 8:kitten
            9:tortoise; they control the encoder effort in ascending order.
            This also affects memory usage: using lower effort will typically
            reduce memory consumption during encoding.
            - lightning and thunder are fast modes useful for lossless mode
              (modular).
            - falcon disables all of the following tools.
            - cheetah enables coefficient reordering, context clustering, and
              heuristics for selecting DCT sizes and quantization steps.
            - hare enables Gaborish filtering, chroma from luma, and an
              initial estimate of quantization steps.
            - wombat enables error diffusion quantization and full DCT size
              selection heuristics.
            - squirrel (default) enables dots, patches, and spline detection,
              and full context clustering.
            - kitten optimizes the adaptive quantization for a psychovisual
              metric.
            - tortoise enables a more thorough adaptive quantization search.
        distance : Default to 1.0.
            Sets the distance level for lossy compression: target max
            butteraugli distance, lower = higher quality. Range: 0 .. 15.
            0.0 = mathematically lossless (however, use
            JxlEncoderSetFrameLossless instead to use true lossless, as
            setting distance to 0 alone is not the only requirement).
            1.0 = visually lossless. Recommended range: 0.5 .. 3.0.
        lossless : Default to False.
            Use lossless encoding. The value is stored as ``bool(lossless)``.
        decodingspeed : Default to 0.
            Duplicate of level. [0,4]
        photometric : Return JxlColorSpace value.
            Default logic is quite complicated but works most of the time.
            Accepted value:
                int: [-1,3]
                str: ['RGB',
                      'WHITEISZERO', 'MINISWHITE',
                      'BLACKISZERO', 'MINISBLACK', 'GRAY',
                      'XYB', 'KNOWN']
        planar : Enable multi-channel mode.
            Default to false.
        usecontainer :
            Forces the encoder to use the box-based container format (BMFF)
            even when not necessary.
            When using JxlEncoderUseBoxes, JxlEncoderStoreJPEGMetadata or
            JxlEncoderSetCodestreamLevel with level 10, the encoder will
            automatically also use the container format, it is not necessary
            to use JxlEncoderUseContainer for those use cases.
            By default this setting is disabled.
        index : Selectively decode frames for animation.
            Default to 0, decode all frames.
            When set to > 0, decode that frame index only.
        keeporientation :
            Enables or disables preserving of as-in-bitstream pixeldata
            orientation. Some images are encoded with an Orientation tag
            indicating that the decoder must perform a rotation and/or
            mirroring to the encoded image data.

            If skip_reorientation is JXL_FALSE (the default): the decoder will
            apply the transformation from the orientation setting, hence
            rendering the image according to its specified intent. When
            producing a JxlBasicInfo, the decoder will always set the
            orientation field to JXL_ORIENT_IDENTITY (matching the returned
            pixel data) and also align xsize and ysize so that they correspond
            to the width and the height of the returned pixel data.

            If skip_reorientation is JXL_TRUE: the decoder will skip applying
            the transformation from the orientation setting, returning the
            image in the as-in-bitstream pixeldata orientation. This may be
            faster to decode since the decoder doesn't have to apply the
            transformation, but can cause wrong display of the image if the
            orientation tag is not correctly taken into account by the user.

            By default, this option is disabled, and the returned pixel data
            is re-oriented according to the image's Orientation setting.
        numthreads : Default to 1.
            If <= 0, use all cores.
            If > 32, clipped to 32.
        """
        # encode options
        self.level = level
        self.effort = effort
        self.distance = distance
        self.lossless = bool(lossless)
        self.decodingspeed = decodingspeed
        self.photometric = photometric
        self.planar = planar
        self.usecontainer = usecontainer
        # decode options
        self.index = index
        self.keeporientation = keeporientation
        # shared by encode and decode
        self.numthreads = numthreads

    def encode(self, buf):
        """Encode ``buf`` with ``imagecodecs.jpegxl_encode`` after reshaping it with ``protective_squeeze``."""
        # TODO: only squeeze all but last dim
        image = protective_squeeze(numpy.asarray(buf))
        encode_options = {
            "level": self.level,
            "effort": self.effort,
            "distance": self.distance,
            "lossless": self.lossless,
            "decodingspeed": self.decodingspeed,
            "photometric": self.photometric,
            "planar": self.planar,
            "usecontainer": self.usecontainer,
            "numthreads": self.numthreads,
        }
        return imagecodecs.jpegxl_encode(image, **encode_options)

    def decode(self, buf, out=None):
        """Decode ``buf`` with ``imagecodecs.jpegxl_decode``, optionally into ``out``."""
        decode_options = {
            "index": self.index,
            "keeporientation": self.keeporientation,
            "numthreads": self.numthreads,
            "out": out,
        }
        return imagecodecs.jpegxl_decode(buf, **decode_options)
|
||||
|
||||
|
||||
def _flat(out):
|
||||
"""Return numpy array as contiguous view of bytes if possible."""
|
||||
if out is None:
|
||||
return None
|
||||
view = memoryview(out)
|
||||
if view.readonly or not view.contiguous:
|
||||
return None
|
||||
return view.cast("B")
|
||||
|
||||
|
||||
def register_codecs(codecs=None, force=False, verbose=True):
    """Register codecs in this module with numcodecs.

    Every module-level object exposing a ``codec_id`` attribute (except the
    name ``Codec`` itself) is considered. ``codecs``, when given, restricts
    registration to the listed codec ids. Codecs that are already registered
    are skipped unless ``force`` is true; ``verbose`` controls the warnings
    logged in either case.
    """
    for name, candidate in globals().items():
        if name == "Codec" or not hasattr(candidate, "codec_id"):
            continue
        if codecs is not None and candidate.codec_id not in codecs:
            continue

        try:
            get_codec({"id": candidate.codec_id})
        except ValueError:
            # not registered yet
            already_registered = False
        except TypeError:
            # registered, but failed
            already_registered = True
        else:
            already_registered = True

        if already_registered:
            if not force:
                if verbose:
                    log_warning(f"numcodec {candidate.codec_id!r} already registered")
                continue
            if verbose:
                log_warning(f"replacing registered numcodec {candidate.codec_id!r}")
        register_codec(candidate)
|
||||
|
||||
|
||||
def log_warning(msg, *args, **kwargs):
    """Emit ``msg`` at WARNING level on this module's logger; extra arguments are passed through."""
    import logging

    module_logger = logging.getLogger(__name__)
    module_logger.warning(msg, *args, **kwargs)
|
||||
@@ -1,233 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
|
||||
"""
|
||||
|
||||
import gc
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
|
||||
def get_cameras(hdf5_data):
    """Return the keys under "/observations/images" whose name does not contain "depth"."""
    # ignore depth channel, not currently handled
    # TODO(rcadene): add depth
    rgb_cameras = []
    for name in hdf5_data["/observations/images"].keys():  # noqa: SIM118
        if "depth" in name:
            continue
        rgb_cameras.append(name)
    return rgb_cameras
|
||||
|
||||
|
||||
def check_format(raw_dir) -> None:
    """Assert that every ``episode_*.hdf5`` file in ``raw_dir`` has the expected layout.

    Checks, for each episode file:
      * "/action" and "/observations/qpos" exist, are 2-dimensional and have
        the same number of frames;
      * every camera returned by ``get_cameras`` has that same number of
        frames, with 2-dimensional data when images are compressed and
        4-dimensional (b, h, w, c) data otherwise.

    Images are treated as uncompressed only when "sim" appears in the
    directory name.

    Args:
        raw_dir: Directory (``pathlib.Path``-like) holding the episode files.

    Raises:
        AssertionError: if no episode file is found or any check fails.
    """
    # only frames from simulation are uncompressed
    compressed_images = "sim" not in raw_dir.name

    hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
    assert len(hdf5_paths) != 0, f"No 'episode_*.hdf5' file found in {raw_dir}"
    for hdf5_path in hdf5_paths:
        with h5py.File(hdf5_path, "r") as data:
            assert "/action" in data
            assert "/observations/qpos" in data

            assert data["/action"].ndim == 2
            assert data["/observations/qpos"].ndim == 2

            num_frames = data["/action"].shape[0]
            assert num_frames == data["/observations/qpos"].shape[0]

            for camera in get_cameras(data):
                # Look the dataset up once instead of on every assertion.
                images = data[f"/observations/images/{camera}"]
                assert num_frames == images.shape[0]

                if compressed_images:
                    assert images.ndim == 2
                else:
                    assert images.ndim == 4
                    _, h, w, c = images.shape
                    assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
|
||||
|
||||
def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Load `episode_*.hdf5` files from `raw_dir` into a single concatenated data dict.

    Args:
        raw_dir: Directory holding the episode files. Images are treated as
            compressed (decoded with `cv2.imdecode`) unless the directory name
            contains "sim".
        videos_dir: Directory that receives the encoded mp4 files (and a
            temporary "tmp_images" sub-directory) when `video` is true.
        fps: Frame rate used to build timestamps and to encode videos.
        video: When true, camera frames are encoded to mp4 and referenced by
            path/timestamp; otherwise they are kept as PIL images.
        episodes: Indices into the sorted list of episode files. All files are
            loaded when this is None or empty.
        encoding: Extra keyword arguments forwarded to `encode_video_frames`.

    Returns:
        The result of `concatenate_episodes` over all loaded episodes, with an
        additional "index" entry counting frames across the whole dataset.
    """
    # only frames from simulation are uncompressed
    compressed_images = "sim" not in raw_dir.name

    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
    num_episodes = len(hdf5_files)

    ep_dicts = []
    ep_ids = episodes if episodes else range(num_episodes)
    for ep_idx in tqdm.tqdm(ep_ids):
        ep_path = hdf5_files[ep_idx]
        with h5py.File(ep_path, "r") as ep:
            num_frames = ep["/action"].shape[0]

            # last step of demonstration is considered done
            done = torch.zeros(num_frames, dtype=torch.bool)
            done[-1] = True

            state = torch.from_numpy(ep["/observations/qpos"][:])
            action = torch.from_numpy(ep["/action"][:])

            # Optional streams. The velocity is stored under "qvel"; the same key
            # must be used both to read it and to decide whether to keep it,
            # otherwise it is silently dropped (or a stale/undefined value is used).
            has_velocity = "/observations/qvel" in ep
            has_effort = "/observations/effort" in ep
            velocity = torch.from_numpy(ep["/observations/qvel"][:]) if has_velocity else None
            effort = torch.from_numpy(ep["/observations/effort"][:]) if has_effort else None

            ep_dict = {}

            for camera in get_cameras(ep):
                img_key = f"observation.images.{camera}"

                if compressed_images:
                    import cv2

                    # load one compressed image after the other in RAM and uncompress
                    imgs_array = np.array(
                        [cv2.imdecode(data, 1) for data in ep[f"/observations/images/{camera}"]]
                    )
                else:
                    # load all images in RAM
                    imgs_array = ep[f"/observations/images/{camera}"][:]

                if video:
                    # save png images in temporary directory
                    tmp_imgs_dir = videos_dir / "tmp_images"
                    save_images_concurrently(imgs_array, tmp_imgs_dir)

                    # encode images to a mp4 video
                    fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                    video_path = videos_dir / fname
                    encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))

                    # clean temporary images directory
                    shutil.rmtree(tmp_imgs_dir)

                    # store the reference to the video frame
                    ep_dict[img_key] = [
                        {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
                    ]
                else:
                    ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

            ep_dict["observation.state"] = state
            if has_velocity:
                ep_dict["observation.velocity"] = velocity
            if has_effort:
                ep_dict["observation.effort"] = effort
            ep_dict["action"] = action
            ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
            ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
            ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
            ep_dict["next.done"] = done
            # TODO(rcadene): add reward and success by computing them in sim

            assert isinstance(ep_idx, int)
            ep_dicts.append(ep_dict)

        # Release the per-episode image buffers before loading the next file.
        gc.collect()

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
    """Wrap `data_dict` in a Hugging Face `Dataset` with an explicit feature schema.

    Every key containing "observation.images." becomes a `VideoFrame` feature
    when `video` is true and an `Image` feature otherwise. State, action and
    the optional velocity/effort entries become fixed-length float32
    sequences sized from the second dimension of the stored data. The dataset
    is returned with `hf_transform_to_torch` installed as its transform.
    """

    def float_vector(key):
        # Fixed-length float32 sequence whose length is the second dimension of data_dict[key].
        return Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))

    features = {
        key: (VideoFrame() if video else Image()) for key in data_dict if "observation.images." in key
    }

    features["observation.state"] = float_vector("observation.state")
    for optional_key in ("observation.velocity", "observation.effort"):
        if optional_key in data_dict:
            features[optional_key] = float_vector(optional_key)
    features["action"] = float_vector("action")

    features.update(
        {
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            "next.done": Value(dtype="bool", id=None),
            "index": Value(dtype="int64", id=None),
        }
    )

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Convert the raw hdf5 episodes in `raw_dir` to the LeRobot dataset format.

    Runs `check_format`, loads the episodes, builds the Hugging Face dataset
    and the episode index, and assembles the info dictionary. `fps` falls
    back to 50 when not given.

    Returns:
        A `(hf_dataset, episode_data_index, info)` tuple.
    """
    # sanity check
    check_format(raw_dir)

    fps = 50 if fps is None else fps

    raw_data = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
    hf_dataset = to_hf_dataset(raw_data, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)

    info = {"codebase_version": CODEBASE_VERSION, "fps": fps, "video": video}
    if video:
        # NOTE(review): the default encoding is recorded even when a custom
        # `encoding` dict was supplied - confirm this is intended.
        info["encoding"] = get_default_encoding()

    return hf_dataset, episode_data_index, info
|
||||
@@ -1,107 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains utilities to process raw data format of png images files recorded with capture_camera_feed.py
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
)
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
def check_format(raw_dir: Path) -> bool:
    """Check that `raw_dir` directly contains at least one `frame_*.png` file.

    Args:
        raw_dir: Directory expected to hold the recorded frames.

    Returns:
        True when at least one matching file exists.

    Raises:
        ValueError: If no `frame_*.png` file is found.
    """
    if not any(raw_dir.glob("frame_*.png")):
        raise ValueError(f"Missing 'frame_*.png' files in '{raw_dir}'")
    return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, fps: int, episodes: list[int] | None = None):
|
||||
if episodes is not None:
|
||||
# TODO(aliberts): add support for multi-episodes.
|
||||
raise NotImplementedError()
|
||||
|
||||
ep_dict = {}
|
||||
ep_idx = 0
|
||||
|
||||
image_paths = sorted(raw_dir.glob("frame_*.png"))
|
||||
num_frames = len(image_paths)
|
||||
|
||||
ep_dict["observation.image"] = [PILImage.open(x) for x in image_paths]
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
|
||||
ep_dicts = [ep_dict]
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
    """Wrap `data_dict` in a Hugging Face `Dataset` with an explicit feature schema.

    "observation.image" is declared as a `VideoFrame` when `video` is true
    and as an `Image` otherwise. The dataset is returned with
    `hf_transform_to_torch` installed as its transform.
    """
    features = {
        "observation.image": VideoFrame() if video else Image(),
        "episode_index": Value(dtype="int64", id=None),
        "frame_index": Value(dtype="int64", id=None),
        "timestamp": Value(dtype="float32", id=None),
        "index": Value(dtype="int64", id=None),
    }

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
if video or episodes or encoding is not None:
|
||||
# TODO(aliberts): support this
|
||||
raise NotImplementedError
|
||||
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 30
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
return hf_dataset, episode_data_index, info
|
||||
@@ -1,233 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains utilities to process raw data format from dora-record
|
||||
"""
|
||||
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
def check_format(raw_dir) -> bool:
    """Check that `raw_dir` exists and directly contains at least one parquet file.

    Returns:
        True when the check passes.

    Raises:
        AssertionError: If `raw_dir` does not exist.
        ValueError: If no `*.parquet` file is found.
    """
    assert raw_dir.exists()

    parquet_files = list(raw_dir.glob("*.parquet"))
    if not parquet_files:
        raise ValueError(f"Missing parquet files in '{raw_dir}'")
    return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
|
||||
# select first camera in alphanumeric order
|
||||
reference_key = sorted(reference_files)[0].stem
|
||||
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
|
||||
reference_df = reference_df[["timestamp_utc", reference_key]]
|
||||
|
||||
# Merge all data stream using nearest backward strategy
|
||||
df = reference_df
|
||||
for path in raw_dir.glob("*.parquet"):
|
||||
key = path.stem # action or observation.state or ...
|
||||
if key == reference_key:
|
||||
continue
|
||||
if "failed_episode_index" in key:
|
||||
# TODO(rcadene): add support for removing episodes that are tagged as "failed"
|
||||
continue
|
||||
modality_df = pd.read_parquet(path)
|
||||
modality_df = modality_df[["timestamp_utc", key]]
|
||||
df = pd.merge_asof(
|
||||
df,
|
||||
modality_df,
|
||||
on="timestamp_utc",
|
||||
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
|
||||
# matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
|
||||
# are too far apart.
|
||||
direction="nearest",
|
||||
tolerance=pd.Timedelta(f"{1 / fps} seconds"),
|
||||
)
|
||||
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
|
||||
df = df[df["episode_index"] != -1]
|
||||
|
||||
image_keys = [key for key in df if "observation.images." in key]
|
||||
|
||||
def get_episode_index(row):
|
||||
episode_index_per_cam = {}
|
||||
for key in image_keys:
|
||||
path = row[key][0]["path"]
|
||||
match = re.search(r"_(\d{6}).mp4", path)
|
||||
if not match:
|
||||
raise ValueError(path)
|
||||
episode_index = int(match.group(1))
|
||||
episode_index_per_cam[key] = episode_index
|
||||
if len(set(episode_index_per_cam.values())) != 1:
|
||||
raise ValueError(
|
||||
f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
|
||||
)
|
||||
return episode_index
|
||||
|
||||
df["episode_index"] = df.apply(get_episode_index, axis=1)
|
||||
|
||||
# dora only use arrays, so single values are encapsulated into a list
|
||||
df["frame_index"] = df.groupby("episode_index").cumcount()
|
||||
df = df.reset_index()
|
||||
df["index"] = df.index
|
||||
|
||||
# set 'next.done' to True for the last frame of each episode
|
||||
df["next.done"] = False
|
||||
df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True
|
||||
|
||||
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
|
||||
# each episode starts with timestamp 0 to match the ones from the video
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
|
||||
|
||||
del df["timestamp_utc"]
|
||||
|
||||
# sanity check
|
||||
has_nan = df.isna().any().any()
|
||||
if has_nan:
|
||||
raise ValueError("Dataset contains Nan values.")
|
||||
|
||||
# sanity check episode indices go from 0 to n-1
|
||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||
if ep_ids != expected_ep_ids:
|
||||
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
|
||||
|
||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||
videos_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formatted
|
||||
for key in df:
|
||||
if "observation.images." not in key:
|
||||
continue
|
||||
for ep_idx in ep_ids:
|
||||
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
|
||||
if not video_path.exists():
|
||||
raise ValueError(f"Video file not found in {video_path}")
|
||||
|
||||
data_dict = {}
|
||||
for key in df:
|
||||
# is video frame
|
||||
if "observation.images." in key:
|
||||
# we need `[0] because dora only use arrays, so single values are encapsulated into a list.
|
||||
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
|
||||
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
|
||||
|
||||
# sanity check the video path is well formatted
|
||||
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
||||
if not video_path.exists():
|
||||
raise ValueError(f"Video file not found in {video_path}")
|
||||
# is number
|
||||
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
|
||||
data_dict[key] = torch.from_numpy(df[key].values)
|
||||
# is vector
|
||||
elif df[key].iloc[0].shape[0] > 1:
|
||||
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
    """Build a Hugging Face `Dataset` from `data_dict` using an explicit feature schema.

    Keys containing "observation.images." are declared as `VideoFrame` when
    `video` is true, otherwise as `Image`. "observation.state", "action" and
    the optional "observation.velocity" / "observation.effort" entries are
    declared as float32 sequences whose length is the second dimension of
    the stored data. `hf_transform_to_torch` is installed as the transform.
    """
    features = {}

    for key in data_dict:
        if "observation.images." in key:
            features[key] = VideoFrame() if video else Image()

    vector_keys = ["observation.state"]
    vector_keys += [key for key in ("observation.velocity", "observation.effort") if key in data_dict]
    vector_keys.append("action")
    for key in vector_keys:
        features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))

    scalar_dtypes = (
        ("episode_index", "int64"),
        ("frame_index", "int64"),
        ("timestamp", "float32"),
        ("next.done", "bool"),
        ("index", "int64"),
    )
    for key, dtype in scalar_dtypes:
        features[key] = Value(dtype=dtype, id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Convert dora-record parquet data in `raw_dir` to the LeRobot dataset format.

    Only the default frame rate (30, selected by leaving `fps` as None) and
    `video=True` are supported. A warning is emitted when `encoding` is given,
    and the recorded encoding is reported as "unknown".

    Returns:
        A `(hf_dataset, episode_data_index, info)` tuple.

    Raises:
        NotImplementedError: If `fps` is not None or `video` is false.
    """
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 30
    else:
        raise NotImplementedError()

    if not video:
        raise NotImplementedError()

    if encoding is not None:
        warnings.warn(
            "Video encoding is currently done outside of LeRobot for the dora_parquet format.",
            stacklevel=1,
        )

    # `load_from_raw` expects (raw_dir, videos_dir, fps, video, episodes); omitting
    # `video` would bind `episodes` to the `video` parameter.
    data_df = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    hf_dataset = to_hf_dataset(data_df, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = "unknown"

    return hf_dataset, episode_data_index, info
|
||||
@@ -1,312 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
For all datasets in the RLDS format.
|
||||
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
||||
|
||||
NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
|
||||
|
||||
Example:
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir /path/to/data/bridge_dataset/1.0.0/ \
|
||||
--repo-id your_hub/sampled_bridge_data_v2 \
|
||||
--raw-format rlds \
|
||||
--episodes 3 4 5 8 9
|
||||
|
||||
Exact dataset fps defined in openx/config.py, obtained from:
|
||||
https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
np.set_printoptions(precision=2)
|
||||
|
||||
|
||||
def tf_to_torch(data):
    """Convert `data` to a torch tensor by way of the array returned by its `.numpy()` method."""
    as_numpy = data.numpy()
    return torch.from_numpy(as_numpy)
|
||||
|
||||
|
||||
def tf_img_convert(img):
    """Return `img` as a numpy array, decoding it first when it is an encoded string tensor.

    Raises:
        ValueError: If the dtype is neither `tf.string` nor `tf.uint8`.
    """
    if img.dtype == tf.string:
        # Encoded image bytes: decode to uint8 (animations are not expanded).
        decoded = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
        return decoded.numpy()
    if img.dtype != tf.uint8:
        raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
    return img.numpy()
|
||||
|
||||
|
||||
def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
    """
    Flatten one RLDS trajectory: lift its "steps" entry to the top level and repeat every
    other (metadata) entry to the trajectory length under the "traj_metadata" key. The extra
    fields `_len`, `_traj_index` and `_frame_index` are added as well.

    Note that "steps" is popped from the `traj` dict that is passed in.

    NOTE: adapted from DLimp library https://github.com/kvablack/dlimp/
    """
    steps = traj.pop("steps")

    # The trajectory length is the leading dimension of the first flattened step entry.
    first_leaf = tf.nest.flatten(steps)[0]
    traj_len = tf.shape(first_leaf)[0]

    # broadcast metadata to the length of the trajectory
    def repeat_to_length(value):
        return tf.repeat(value, traj_len)

    metadata = tf.nest.map_structure(repeat_to_length, traj)

    # put steps back in
    assert "traj_metadata" not in steps
    flat_traj = {**steps, "traj_metadata": metadata}

    assert "_len" not in flat_traj
    assert "_traj_index" not in flat_traj
    assert "_frame_index" not in flat_traj
    flat_traj["_len"] = tf.repeat(traj_len, traj_len)
    flat_traj["_traj_index"] = tf.repeat(i, traj_len)
    flat_traj["_frame_index"] = tf.range(traj_len)

    return flat_traj
|
||||
|
||||
|
||||
def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Load an RLDS dataset from `raw_dir` into a single concatenated data dict.

    Each processed episode is also saved as `ep_dict_<10-digit index>.pt` in an
    "ep_dicts" directory next to `videos_dir`, so that an interrupted run can
    resume after the last saved episode.

    Args:
        raw_dir (Path): Directory passed to `tfds.builder_from_directory`.
        videos_dir (Path): Directory that receives the encoded mp4 files (and a
            temporary "tmp_images" sub-directory) when `video` is true.
        fps (int): Frame rate used to build timestamps and to encode videos.
        video (bool): When true, image observations are encoded to mp4 and
            referenced by path/timestamp; otherwise they are kept as PIL images.
        episodes (list[int] | None, optional): Episode indices to convert. All
            episodes are converted when None. Defaults to None.
        encoding (dict | None, optional): Extra keyword arguments forwarded to
            `encode_video_frames`. Defaults to None.

    Returns:
        The result of `concatenate_episodes` over all episodes, with an
        additional "index" entry counting frames across the whole dataset.

    Raises:
        ValueError: If `episodes` is given but the dataset is empty.
    """
    ds_builder = tfds.builder_from_directory(str(raw_dir))
    dataset = ds_builder.as_dataset(
        split="all",
        decoders={"steps": tfds.decode.SkipDecoding()},
    )

    dataset_info = ds_builder.info
    print("dataset_info: ", dataset_info)

    ds_length = len(dataset)
    dataset = dataset.take(ds_length)
    # "flatten" the dataset as such we can apply trajectory level map() easily
    # each [obs][key] has a shape of (frame_size, ...)
    dataset = dataset.enumerate().map(_broadcast_metadata_rlds)

    # Split the observation keys into image keys and vector ("state") keys
    # based on the declared feature shapes.
    image_keys = []
    state_keys = []
    observation_info = dataset_info.features["steps"]["observation"]
    for key in observation_info:
        # check whether the key is for an image or a vector observation
        if len(observation_info[key].shape) == 3:
            # only adding uint8 images discards depth images
            if observation_info[key].dtype == tf.uint8:
                image_keys.append(key)
        else:
            state_keys.append(key)

    lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None

    print(" - image_keys: ", image_keys)
    print(" - lang_key: ", lang_key)

    it = iter(dataset)

    ep_dicts = []
    # Init temp path to save ep_dicts in case of crash
    tmp_ep_dicts_dir = videos_dir.parent.joinpath("ep_dicts")
    tmp_ep_dicts_dir.mkdir(parents=True, exist_ok=True)

    # check if ep_dicts have already been saved by a previous run
    starting_ep_idx = 0
    saved_ep_dicts = sorted(str(ep) for ep in tmp_ep_dicts_dir.iterdir())
    if saved_ep_dicts:
        # get last ep_idx number (the 10 digits right before the ".pt" suffix)
        starting_ep_idx = int(saved_ep_dicts[-1][-13:-3]) + 1
        # Skip the dataset entries that precede the resume point.
        for _ in range(starting_ep_idx):
            next(it)
        # Reload every saved episode. The saved files are not necessarily
        # contiguous (only a subset is saved when `episodes` is given), so they
        # must not be indexed by episode number.
        # NOTE: torch.load unpickles its input; only resume from a trusted directory.
        ep_dicts.extend(torch.load(saved_path) for saved_path in saved_ep_dicts)

    # if the user specified episodes, skip the ones not in the list
    if episodes is not None:
        if ds_length == 0:
            raise ValueError("No episodes found.")
        # Sorted copy of the requested episodes that still need processing;
        # indices before the resume point are dropped so the head of the list
        # can always be matched by the loop below.
        episodes = sorted(ep for ep in episodes if ep >= starting_ep_idx)

    for ep_idx in tqdm.tqdm(range(starting_ep_idx, ds_length)):
        episode = next(it)

        # if user specified episodes, skip the ones not in the list
        if episodes is not None:
            if len(episodes) == 0:
                break
            if ep_idx == episodes[0]:
                # process this episode
                print(" selecting episode idx: ", ep_idx)
                episodes.pop(0)
            else:
                continue  # skip

        num_frames = episode["action"].shape[0]

        ep_dict = {}
        for key in state_keys:
            ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])

        ep_dict["action"] = tf_to_torch(episode["action"])
        ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
        ep_dict["next.done"] = tf_to_torch(episode["is_last"])
        ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
        ep_dict["is_first"] = tf_to_torch(episode["is_first"])
        ep_dict["discount"] = tf_to_torch(episode["discount"])

        # If lang_key is present, decode the instruction of every frame
        if lang_key is not None:
            ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]

        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)

        # Decode the frames of every camera before encoding or storing any of them.
        image_array_dict = {
            im_key: [tf_img_convert(img) for img in episode["observation"][im_key]] for im_key in image_keys
        }

        # loop through all cameras
        for im_key in image_keys:
            img_key = f"observation.images.{im_key}"
            imgs_array = np.array(image_array_dict[im_key])
            if video:
                # save png images in temporary directory
                tmp_imgs_dir = videos_dir / "tmp_images"
                save_images_concurrently(imgs_array, tmp_imgs_dir)

                # encode images to a mp4 video
                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                video_path = videos_dir / fname
                encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))

                # clean temporary images directory
                shutil.rmtree(tmp_imgs_dir)

                # store the reference to the video frame
                ep_dict[img_key] = [
                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
                ]
            else:
                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

        # Persist the episode so that a crashed run can resume from here.
        path_ep_dict = tmp_ep_dicts_dir.joinpath(f"ep_dict_{ep_idx:010d}.pt")
        torch.save(ep_dict, path_ep_dict)

        ep_dicts.append(ep_dict)

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
    """Wrap `data_dict` in a Hugging Face `Dataset` with an explicit feature schema.

    Keys containing "observation.images." become `VideoFrame` (when `video`
    is true) or `Image` features; every other key starting with
    "observation." becomes a float32 sequence sized from the second
    dimension of the stored data. "language_instruction" is declared only
    when present. `hf_transform_to_torch` is installed as the transform.
    """
    features = {}

    for key in data_dict:
        if "observation.images." in key:
            # image observation
            features[key] = VideoFrame() if video else Image()
        elif key.startswith("observation."):
            # vector state observation
            features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))

    if "language_instruction" in data_dict:
        features["language_instruction"] = Value(dtype="string", id=None)

    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )

    scalar_dtypes = (
        ("is_terminal", "bool"),
        ("is_first", "bool"),
        ("discount", "float32"),
        ("episode_index", "int64"),
        ("frame_index", "int64"),
        ("timestamp", "float32"),
        ("next.reward", "float32"),
        ("next.done", "bool"),
        ("index", "int64"),
    )
    for key, dtype in scalar_dtypes:
        features[key] = Value(dtype=dtype, id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Convert an RLDS dataset in `raw_dir` to the LeRobot dataset format.

    Returns:
        A `(hf_dataset, episode_data_index, info)` tuple.
    """
    # NOTE(review): `fps` is forwarded unchanged, so the default of None reaches
    # `load_from_raw`, which divides by it - callers presumably must always
    # pass an explicit fps; confirm.
    raw_data = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
    hf_dataset = to_hf_dataset(raw_data, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)

    info = {"codebase_version": CODEBASE_VERSION, "fps": fps, "video": video}
    if video:
        info["encoding"] = get_default_encoding()

    return hf_dataset, episode_data_index, info
|
||||
@@ -1,275 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
import zarr
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
|
||||
def check_format(raw_dir):
    """Assert that `raw_dir` holds a PushT zarr store with the expected layout.

    Opens `pusht_cchi_v7_replay.zarr` read-only, asserts that every required
    dataset is present, and asserts that every `data/*` dataset has the same
    first-dimension length as `data/img`.

    Raises:
        AssertionError: If a dataset is missing or a length differs.
    """
    store = zarr.open(raw_dir / "pusht_cchi_v7_replay.zarr", mode="r")

    per_frame = (
        "data/action",
        "data/img",
        "data/keypoint",
        "data/n_contacts",
        "data/state",
    )

    # Presence check covers the per-frame arrays plus the episode boundaries.
    for name in (*per_frame, "meta/episode_ends"):
        assert name in store

    # `meta/episode_ends` is excluded here: only per-frame arrays are compared.
    expected_frames = store["data/img"].shape[0]
    for name in per_frame:
        assert expected_frames == store[name].shape[0]
|
||||
|
||||
|
||||
def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    keypoints_instead_of_image: bool = False,
    encoding: dict | None = None,
):
    """Load the PushT replay buffer and build per-frame data for the selected episodes.

    For every selected episode the block position/angle stored in the state is
    replayed into a fresh pymunk space to recompute goal coverage, from which
    `reward` (coverage / threshold, clipped to [0, 1]) and `success`
    (coverage above the threshold) are derived. Images are either encoded to
    one mp4 per episode under `videos_dir` or kept as PIL images.

    Args:
        raw_dir: Directory containing `pusht_cchi_v7_replay.zarr`.
        videos_dir: Directory receiving the encoded videos and a temporary
            `tmp_images` folder.
        fps: Frame rate used for timestamps and video encoding.
        video: If True, encode images to mp4 and store frame references.
        episodes: Episode indices to convert; a falsy value selects all.
        keypoints_instead_of_image: If True, store the T-block keypoints as
            `observation.environment_state` instead of images.
        encoding: Extra keyword arguments for `encode_video_frames`.

    Returns:
        The dictionary produced by `concatenate_episodes`, extended with a
        global `index` column.

    Raises:
        ModuleNotFoundError: If the optional PushT dependencies are missing.
        AssertionError: If the buffer fails one of the sanity checks.
    """
    try:
        import pymunk
        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely

        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
            ReplayBuffer as DiffusionPolicyReplayBuffer,
        )
    except ModuleNotFoundError as e:
        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
        raise e
    # as defined in the gym-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
    success_threshold = 0.95  # 95% coverage

    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

    episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
    # All arrays must share one frame count, i.e. the set of lengths has exactly
    # one element. A bare `assert len(...)` only checks the set is non-empty.
    assert (
        len({zarr_data[key].shape[0] for key in zarr_data.keys()}) == 1  # noqa: SIM118
    ), "Some data type dont have the same number of total frames."

    # TODO(rcadene): verify that goal pose is expected to be fixed
    goal_pos_angle = np.array([256, 256, np.pi / 4])  # x, y, theta (in radians)
    goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

    imgs = torch.from_numpy(zarr_data["img"])  # b h w c
    states = torch.from_numpy(zarr_data["state"])
    actions = torch.from_numpy(zarr_data["action"])

    # load data indices from which each episode starts and ends
    from_ids, to_ids = [], []
    from_idx = 0
    for to_idx in zarr_data.meta["episode_ends"]:
        from_ids.append(from_idx)
        to_ids.append(to_idx)
        from_idx = to_idx

    num_episodes = len(from_ids)

    ep_dicts = []
    ep_ids = episodes if episodes else range(num_episodes)
    # `ep_idx` is the position in the output dataset, `selected_ep_idx` is the
    # episode's index in the raw buffer; they differ when `episodes` is a subset.
    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
        from_idx = from_ids[selected_ep_idx]
        to_idx = to_ids[selected_ep_idx]
        num_frames = to_idx - from_idx

        # sanity check: the buffer's per-frame episode ids refer to the raw
        # episode index, so compare against `selected_ep_idx`, not `ep_idx`.
        assert (episode_ids[from_idx:to_idx] == selected_ep_idx).all()

        # get image
        if not keypoints_instead_of_image:
            image = imgs[from_idx:to_idx]
            assert image.min() >= 0.0
            assert image.max() <= 255.0
            image = image.type(torch.uint8)

        # get state
        state = states[from_idx:to_idx]
        agent_pos = state[:, :2]
        block_pos = state[:, 2:4]
        block_angle = state[:, 4]

        # get reward, success, done, and (maybe) keypoints
        reward = torch.zeros(num_frames)
        success = torch.zeros(num_frames, dtype=torch.bool)
        if keypoints_instead_of_image:
            keypoints = torch.zeros(num_frames, 16)  # 8 keypoints each with 2 coords
        done = torch.zeros(num_frames, dtype=torch.bool)
        for i in range(num_frames):
            space = pymunk.Space()
            space.gravity = 0, 0
            space.damping = 0

            # Add walls.
            walls = [
                PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
                PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
                PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
                PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
            ]
            space.add(*walls)

            block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
            goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
            block_geom = pymunk_to_shapely(block_body, block_body.shapes)
            intersection_area = goal_geom.intersection(block_geom).area
            goal_area = goal_geom.area
            coverage = intersection_area / goal_area
            reward[i] = np.clip(coverage / success_threshold, 0, 1)
            success[i] = coverage > success_threshold
            if keypoints_instead_of_image:
                keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())

        # last step of demonstration is considered done
        done[-1] = True

        ep_dict = {}

        if not keypoints_instead_of_image:
            imgs_array = [x.numpy() for x in image]
            img_key = "observation.image"
            if video:
                # save png images in temporary directory
                tmp_imgs_dir = videos_dir / "tmp_images"
                save_images_concurrently(imgs_array, tmp_imgs_dir)

                # encode images to a mp4 video
                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                video_path = videos_dir / fname
                encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))

                # clean temporary images directory
                shutil.rmtree(tmp_imgs_dir)

                # store the reference to the video frame
                ep_dict[img_key] = [
                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
                ]
            else:
                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = agent_pos
        if keypoints_instead_of_image:
            ep_dict["observation.environment_state"] = keypoints
        ep_dict["action"] = actions[from_idx:to_idx]
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        # ep_dict["next.observation.image"] = image[1:],
        # ep_dict["next.observation.state"] = agent_pos[1:],
        # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
        # "next.*" columns are shifted by one frame; the last value is repeated.
        ep_dict["next.reward"] = torch.cat([reward[1:], reward[[-1]]])
        ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
        ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
        ep_dicts.append(ep_dict)
    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
    """Wrap the PushT `data_dict` in a Hugging Face `Dataset` with explicit features.

    Args:
        data_dict: Column dictionary as produced by `load_from_raw`.
        video: If True the image column is declared as `VideoFrame`, else `Image`.
        keypoints_instead_of_image: If True no image column is declared and
            `observation.environment_state` is declared instead.

    Returns:
        The dataset, with `hf_transform_to_torch` installed as its transform.
    """

    def float_vector(column):
        # Fixed-length float32 sequence sized from the column's second dimension.
        return Sequence(length=data_dict[column].shape[1], feature=Value(dtype="float32", id=None))

    schema = {}
    if not keypoints_instead_of_image:
        schema["observation.image"] = VideoFrame() if video else Image()

    schema["observation.state"] = float_vector("observation.state")
    if keypoints_instead_of_image:
        schema["observation.environment_state"] = float_vector("observation.environment_state")
    schema["action"] = float_vector("action")

    scalar_columns = (
        ("episode_index", "int64"),
        ("frame_index", "int64"),
        ("timestamp", "float32"),
        ("next.reward", "float32"),
        ("next.done", "bool"),
        ("next.success", "bool"),
        ("index", "int64"),
    )
    for column, dtype in scalar_columns:
        schema[column] = Value(dtype=dtype, id=None)

    dataset = Dataset.from_dict(data_dict, features=Features(schema))
    dataset.set_transform(hf_transform_to_torch)
    return dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Convert a raw PushT zarr directory into the LeRobot dataset triplet.

    Args:
        raw_dir: Directory containing the PushT zarr store.
        videos_dir: Directory receiving encoded videos.
        fps: Frame rate; 10 is used when None.
        video: Whether image observations are stored as videos.
        episodes: Optional episode selection.
        encoding: Optional encoding options forwarded to `load_from_raw`.

    Returns:
        A tuple `(hf_dataset, episode_data_index, info)`.
    """
    # Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
    # with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
    keypoints_instead_of_image = False

    # sanity check
    check_format(raw_dir)

    fps = 10 if fps is None else fps

    frames = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
    dataset = to_hf_dataset(frames, video, keypoints_instead_of_image)
    index = calculate_episode_data_index(dataset)

    metadata = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        # No video is recorded when keypoints replace the image observation.
        "video": 0 if keypoints_instead_of_image else video,
    }
    if video:
        metadata["encoding"] = get_default_encoding()

    return dataset, index, metadata
|
||||
@@ -1,234 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
import zarr
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
|
||||
def check_format(raw_dir) -> bool:
    """Check that `raw_dir` holds a UMI `cup_in_the_wild.zarr` store with the expected layout.

    Args:
        raw_dir: Directory expected to contain `cup_in_the_wild.zarr`.

    Returns:
        False if any required dataset is missing, True once all checks passed.

    Raises:
        AssertionError: If the `data/*` datasets do not all have the same
            first-dimension length as `data/camera0_rgb`.
    """
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    zarr_data = zarr.open(zarr_path, mode="r")

    required_datasets = {
        "data/robot0_demo_end_pose",
        "data/robot0_demo_start_pose",
        "data/robot0_eef_pos",
        "data/robot0_eef_rot_axis_angle",
        "data/robot0_gripper_width",
        "meta/episode_ends",
        "data/camera0_rgb",
    }
    for dataset in required_datasets:
        if dataset not in zarr_data:
            return False

    # mandatory to access zarr_data
    register_codecs()
    nb_frames = zarr_data["data/camera0_rgb"].shape[0]

    # `meta/episode_ends` is excluded from the frame-count comparison.
    required_datasets.remove("meta/episode_ends")
    assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)

    # Report success explicitly: falling off the end would return None, which
    # is falsy and indistinguishable from failure for `if check_format(...)`.
    return True
|
||||
|
||||
|
||||
def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Load the UMI zarr store and build per-frame data for the selected episodes.

    Each processed episode is saved with `torch.save` under
    `videos_dir / "ep_dicts"` and reloaded from there on later runs instead of
    being recomputed.

    Args:
        raw_dir: Directory containing `cup_in_the_wild.zarr`.
        videos_dir: Directory receiving videos, the `ep_dicts` cache and a
            temporary `tmp_images` folder.
        fps: Frame rate used for timestamps and video encoding.
        video: If True, encode images to mp4 and store frame references.
        episodes: Episode indices to convert; a falsy value selects all.
        encoding: Extra keyword arguments for `encode_video_frames`.

    Returns:
        The dictionary produced by `concatenate_episodes`, extended with a
        global `index` column.
    """
    store = zarr.open(raw_dir / "cup_in_the_wild.zarr", mode="r")

    def as_tensor(name):
        return torch.from_numpy(store[name][:])

    # We process the image data separately because it is too large to fit in memory
    end_pose = as_tensor("data/robot0_demo_end_pose")
    start_pos = as_tensor("data/robot0_demo_start_pose")
    eff_pos = as_tensor("data/robot0_eef_pos")
    eff_rot_axis_angle = as_tensor("data/robot0_eef_rot_axis_angle")
    gripper_width = as_tensor("data/robot0_gripper_width")

    # State = end-effector position, rotation (axis-angle), then gripper width.
    states = torch.cat([torch.cat([eff_pos, eff_rot_axis_angle], dim=1), gripper_width], dim=1)

    raw_ends = store["meta/episode_ends"][:]
    num_episodes = raw_ends.shape[0]

    # Episode boundaries: each episode starts where the previous one ended.
    stops = list(torch.from_numpy(raw_ends))
    starts = [0, *stops][: len(stops)]

    cache_dir = videos_dir / "ep_dicts"
    cache_dir.mkdir(exist_ok=True, parents=True)
    collected = []

    chosen = episodes if episodes else range(num_episodes)
    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(chosen)):
        cache_file = cache_dir / f"{ep_idx}"
        if cache_file.is_file():
            # Episode already processed by a previous run.
            collected.append(torch.load(cache_file))
            continue

        from_idx = starts[selected_ep_idx]
        to_idx = stops[selected_ep_idx]
        num_frames = to_idx - from_idx

        # TODO(rcadene): save temporary images of the episode?

        state = states[from_idx:to_idx]

        record = {}

        # load 57MB of images in RAM (400x224x224x3 uint8)
        imgs_array = store["data/camera0_rgb"][from_idx:to_idx]
        img_key = "observation.image"
        if not video:
            record[img_key] = [PILImage.fromarray(x) for x in imgs_array]
        else:
            fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
            video_path = videos_dir / fname
            if not video_path.is_file():
                # save png images in temporary directory
                tmp_imgs_dir = videos_dir / "tmp_images"
                save_images_concurrently(imgs_array, tmp_imgs_dir)

                # encode images to a mp4 video
                encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))

                # clean temporary images directory
                shutil.rmtree(tmp_imgs_dir)

            # store the reference to the video frame
            record[img_key] = [
                {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
            ]

        record["observation.state"] = state
        record["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
        record["frame_index"] = torch.arange(0, num_frames, 1)
        record["timestamp"] = torch.arange(0, num_frames, 1) / fps
        record["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
        record["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
        record["end_pose"] = end_pose[from_idx:to_idx]
        record["start_pos"] = start_pos[from_idx:to_idx]
        record["gripper_width"] = gripper_width[from_idx:to_idx]
        torch.save(record, cache_file)

        collected.append(record)

    data_dict = concatenate_episodes(collected)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video):
    """Wrap the UMI `data_dict` in a Hugging Face `Dataset` with explicit features.

    Args:
        data_dict: Column dictionary as produced by `load_from_raw`.
        video: If True the image column is declared as `VideoFrame`, else `Image`.

    Returns:
        The dataset, with `hf_transform_to_torch` installed as its transform.
    """

    def float_vector(column):
        # Fixed-length float32 sequence sized from the column's second dimension.
        return Sequence(length=data_dict[column].shape[1], feature=Value(dtype="float32", id=None))

    schema = {"observation.image": VideoFrame() if video else Image()}
    schema["observation.state"] = float_vector("observation.state")

    for column, dtype in (
        ("episode_index", "int64"),
        ("frame_index", "int64"),
        ("timestamp", "float32"),
        ("index", "int64"),
        ("episode_data_index_from", "int64"),
        ("episode_data_index_to", "int64"),
    ):
        schema[column] = Value(dtype=dtype, id=None)

    # `start_pos` and `end_pose` respectively represent the positions of the end-effector
    # at the beginning and the end of the episode.
    # `gripper_width` indicates the distance between the grippers; it is also the last
    # component of the state vector.
    for column in ("end_pose", "start_pos", "gripper_width"):
        schema[column] = float_vector(column)

    dataset = Dataset.from_dict(data_dict, features=Features(schema))
    dataset.set_transform(hf_transform_to_torch)
    return dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Convert a raw UMI zarr directory into the LeRobot dataset triplet.

    Args:
        raw_dir: Directory containing the UMI zarr store.
        videos_dir: Directory receiving encoded videos.
        fps: Frame rate; 10 is used when None.
        video: Whether image observations are stored as videos.
        episodes: Optional episode selection.
        encoding: Optional encoding options forwarded to `load_from_raw`.

    Returns:
        A tuple `(hf_dataset, episode_data_index, info)`.

    Raises:
        ValueError: If `check_format` reports that required datasets are missing.
    """
    # sanity check
    # Only an explicit False means the layout check failed; any other result is
    # treated as success. Previously the result was discarded, so a store with
    # missing datasets only failed later, deep inside `load_from_raw`.
    if check_format(raw_dir) is False:
        raise ValueError(f"{raw_dir} does not contain the datasets required for the UMI zarr format.")

    if fps is None:
        # For umi cup in the wild: https://arxiv.org/pdf/2402.10329#table.caption.16
        fps = 10

    if not video:
        logging.warning(
            "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
        )

    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = get_default_encoding()

    return hf_dataset, episode_data_index, info
|
||||
@@ -1,200 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
|
||||
|
||||
import pickle
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
|
||||
def check_format(raw_dir):
    """Assert that the first `*.pkl` file in `raw_dir` looks like an xarm buffer.

    The pickle must hold a dict with `actions`, `rewards` and `dones` of equal
    length. Where `observations` / `next_observations` carry `rgb` or `state`
    entries, those must have that same length; absent entries are not checked.

    Raises:
        AssertionError: If no pickle file is found or a check fails.
    """
    flat_fields = ("actions", "rewards", "dones")
    grouped_fields = ("observations", "next_observations")
    grouped_subfields = ("rgb", "state")

    candidates = list(raw_dir.glob("*.pkl"))
    assert candidates

    # Only the first matching file is inspected.
    with open(candidates[0], "rb") as handle:
        payload = pickle.load(handle)

    assert isinstance(payload, dict)
    for name in flat_fields:
        assert name in payload

    # Check for consistent lengths, using `actions` as the reference.
    reference = len(payload["actions"])
    for name in flat_fields:
        assert len(payload[name]) == reference

    for group in grouped_fields:
        members = payload.get(group, {})
        for sub in grouped_subfields:
            if sub in members:
                assert len(members[sub]) == reference
|
||||
|
||||
|
||||
def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Load `buffer.pkl` and build per-frame data for the selected episodes.

    Episodes are delimited by truthy entries of `dones`; frames after the last
    truthy entry do not belong to any episode.

    Args:
        raw_dir: Directory containing `buffer.pkl`.
        videos_dir: Directory receiving the encoded videos and a temporary
            `tmp_images` folder.
        fps: Frame rate used for timestamps and video encoding.
        video: If True, encode images to mp4 and store frame references.
        episodes: Episode indices to convert; a falsy value selects all.
        encoding: Extra keyword arguments for `encode_video_frames`.

    Returns:
        The dictionary produced by `concatenate_episodes`, extended with a
        global `index` column.
    """
    with open(raw_dir / "buffer.pkl", "rb") as handle:
        buffer = pickle.load(handle)

    # An episode ends (exclusive) right after each truthy `done`, and the next
    # one starts at that same position.
    stops = [pos + 1 for pos, flag in enumerate(buffer["dones"]) if flag]
    starts = [0, *stops][: len(stops)]

    chosen = episodes if episodes else range(len(stops))
    collected = []
    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(chosen)):
        from_idx = starts[selected_ep_idx]
        to_idx = stops[selected_ep_idx]
        num_frames = to_idx - from_idx

        frames = torch.tensor(buffer["observations"]["rgb"][from_idx:to_idx])
        frames = einops.rearrange(frames, "b c h w -> b h w c")
        state = torch.tensor(buffer["observations"]["state"][from_idx:to_idx])
        action = torch.tensor(buffer["actions"][from_idx:to_idx])
        # TODO(rcadene): we have a missing last frame which is the observation when the env is done
        # it is critical to have this frame for tdmpc to predict a "done observation/state"
        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
        # next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
        next_reward = torch.tensor(buffer["rewards"][from_idx:to_idx])
        next_done = torch.tensor(buffer["dones"][from_idx:to_idx])

        imgs_array = [frame.numpy() for frame in frames]
        img_key = "observation.image"
        if not video:
            image_column = [PILImage.fromarray(x) for x in imgs_array]
        else:
            # save png images in temporary directory
            tmp_imgs_dir = videos_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to a mp4 video
            fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
            encode_video_frames(tmp_imgs_dir, videos_dir / fname, fps, **(encoding or {}))

            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)

            # store the reference to the video frame
            image_column = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]

        collected.append(
            {
                img_key: image_column,
                "observation.state": state,
                "action": action,
                "episode_index": torch.tensor([ep_idx] * num_frames, dtype=torch.int64),
                "frame_index": torch.arange(0, num_frames, 1),
                "timestamp": torch.arange(0, num_frames, 1) / fps,
                # "next.observation.image": next_image,
                # "next.observation.state": next_state,
                "next.reward": next_reward,
                "next.done": next_done,
            }
        )

    data_dict = concatenate_episodes(collected)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video):
    """Wrap the xarm `data_dict` in a Hugging Face `Dataset` with explicit features.

    Args:
        data_dict: Column dictionary as produced by `load_from_raw`.
        video: If True the image column is declared as `VideoFrame`, else `Image`.

    Returns:
        The dataset, with `hf_transform_to_torch` installed as its transform.
    """

    def float_vector(column):
        # Fixed-length float32 sequence sized from the column's second dimension.
        return Sequence(length=data_dict[column].shape[1], feature=Value(dtype="float32", id=None))

    schema = {"observation.image": VideoFrame() if video else Image()}
    for column in ("observation.state", "action"):
        schema[column] = float_vector(column)

    for column, dtype in (
        ("episode_index", "int64"),
        ("frame_index", "int64"),
        ("timestamp", "float32"),
        ("next.reward", "float32"),
        ("next.done", "bool"),
        ("index", "int64"),
    ):
        schema[column] = Value(dtype=dtype, id=None)
    # TODO(rcadene): add success
    # features["next.success"] = Value(dtype='bool', id=None)

    dataset = Dataset.from_dict(data_dict, features=Features(schema))
    dataset.set_transform(hf_transform_to_torch)
    return dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Convert a raw xarm pickle directory into the LeRobot dataset triplet.

    Args:
        raw_dir: Directory containing the raw pickle data.
        videos_dir: Directory receiving encoded videos.
        fps: Frame rate; 15 is used when None.
        video: Whether image observations are stored as videos.
        episodes: Optional episode selection.
        encoding: Optional encoding options forwarded to `load_from_raw`.

    Returns:
        A tuple `(hf_dataset, episode_data_index, info)`.
    """
    # sanity check
    check_format(raw_dir)

    fps = 15 if fps is None else fps

    frames = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
    dataset = to_hf_dataset(frames, video)
    index = calculate_episode_data_index(dataset)

    metadata = {"codebase_version": CODEBASE_VERSION, "fps": fps, "video": video}
    if video:
        metadata["encoding"] = get_default_encoding()

    return dataset, index, metadata
|
||||
@@ -427,7 +427,7 @@ def legacy_load_episodes_stats(local_dir: Path) -> dict:
|
||||
def backward_compatible_episodes_stats(
    stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[int, dict[str, dict[str, np.ndarray]]]:
    """Map every episode index to the same dataset-level `stats` object.

    Args:
        stats: Statistics dictionary to reuse for every episode.
        episodes: Episode indices to use as keys.

    Returns:
        A dictionary keyed by episode index. All values are the very same
        `stats` object (not copies), so mutating one mutates all.
    """
    # A single return: the former dict comprehension and this call build the
    # same mapping, and a second statement after the first was unreachable.
    return dict.fromkeys(episodes, stats)
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
|
||||
@@ -479,7 +479,7 @@ def convert_dataset(
|
||||
|
||||
# Tasks
|
||||
if single_task:
|
||||
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
|
||||
tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_path:
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
@@ -45,7 +59,7 @@ def convert_dataset(
|
||||
num_workers: int = 4,
|
||||
):
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
dataset = LeRobotDataset(repo_id) #, revision=V20) #, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / LEGACY_EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / LEGACY_EPISODES_STATS_PATH).unlink()
|
||||
@@ -57,21 +71,21 @@ def convert_dataset(
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
#dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
# hub_api = HfApi()
|
||||
# if hub_api.file_exists(
|
||||
# repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
# ):
|
||||
# hub_api.delete_file(
|
||||
# path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
# )
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
# hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
@@ -29,6 +30,46 @@ from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_safe_default_codec():
    """Return the name of the video decoding backend to use by default.

    Returns:
        str: "torchcodec" when that package can be located with
            `importlib.util.find_spec`; otherwise "pyav" (a warning is logged
            in that case).
    """
    # A bare `import importlib` does not guarantee that the `importlib.util`
    # submodule has been loaded, so import it explicitly before using it.
    import importlib.util

    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"

    logging.warning(
        "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
    )
    return "pyav"
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
|
||||
Args:
|
||||
video_path (Path): Path to the video file.
|
||||
timestamps (list[float]): List of timestamps to extract frames.
|
||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames.
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_codec()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
@@ -127,6 +168,81 @@ def decode_video_frames_torchvision(
|
||||
return closest_frames
|
||||
|
||||
|
||||
def decode_video_frames_torchcodec(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
device: str = "cpu",
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||
|
||||
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
"""
|
||||
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
raise ImportError("torchcodec is required but not available.")
|
||||
|
||||
# initialize video decoder
|
||||
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
# get metadata for frame information
|
||||
metadata = decoder.metadata
|
||||
average_fps = metadata.average_fps
|
||||
|
||||
# convert timestamps to frame indices
|
||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||
|
||||
# retrieve frames based on indices
|
||||
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||
|
||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
|
||||
# compute distances between each query timestamp and loaded timestamps
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
|
||||
# convert to float32 in [0,1] range (channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
@@ -141,6 +257,7 @@ def encode_video_frames(
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ffmpeg_args = OrderedDict(
|
||||
|
||||
@@ -1 +1,15 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -13,7 +13,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@@ -86,3 +90,38 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
policy_features[policy_key] = feature
|
||||
|
||||
return policy_features
|
||||
|
||||
|
||||
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
    """Return True when every sub-environment in `env.envs` has exactly the type of the first one.

    The comparison is on type identity, so an instance of a subclass counts as a different type.
    """
    sub_envs = env.envs
    reference_type = type(sub_envs[0])
    for sub_env in sub_envs:
        if type(sub_env) is not reference_type:
            return False
    return True
|
||||
|
||||
|
||||
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
    """Emit UserWarnings about task attributes and sub-environment type mismatches.

    Two independent checks are made:
      * the first sub-environment is expected to expose both `task_description` and `task`;
      * all sub-environments are expected to share the same type.

    Nothing is raised or returned; problems are only reported through `warnings.warn`.
    """
    with warnings.catch_warnings():
        # The "once" filter is installed only for the duration of this call.
        warnings.simplefilter("once", UserWarning)

        first_env = env.envs[0]
        if not hasattr(first_env, "task_description") or not hasattr(first_env, "task"):
            warnings.warn(
                "The environment does not have 'task_description' and 'task'. Some policies require these features.",
                UserWarning,
                stacklevel=2,
            )

        envs_share_type = are_all_envs_same_type(env)
        if not envs_share_type:
            warnings.warn(
                "The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
                UserWarning,
                stacklevel=2,
            )
|
||||
|
||||
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
    """Fill `observation["task"]` based on the attributes of the first sub-environment.

    `task_description` takes precedence over `task`; the chosen attribute is fetched with
    `env.call`. When neither attribute exists, one empty string per environment is stored,
    the count being the leading dimension of the first entry in `observation`.

    The `observation` dict is modified in place and also returned.
    """
    first_env = env.envs[0]

    for attr_name in ("task_description", "task"):
        if hasattr(first_env, attr_name):
            observation["task"] = env.call(attr_name)
            return observation

    # For envs without language instructions, e.g. aloha transfer cube and etc.
    first_key = list(observation)[0]
    num_envs = observation[first_key].shape[0]
    observation["task"] = [""] * num_envs
    return observation
|
||||
|
||||
@@ -1 +1,15 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .optimizers import OptimizerConfig as OptimizerConfig
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
|
||||
@@ -119,9 +119,7 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -149,9 +147,8 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
@@ -413,11 +410,10 @@ class ACT(nn.Module):
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
batch_size = (
|
||||
batch["observation.images"]
|
||||
if "observation.images" in batch
|
||||
else batch["observation.environment_state"]
|
||||
).shape[0]
|
||||
if "observation.images" in batch:
|
||||
batch_size = batch["observation.images"][0].shape[0]
|
||||
else:
|
||||
batch_size = batch["observation.environment_state"].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and "action" in batch:
|
||||
@@ -490,20 +486,21 @@ class ACT(nn.Module):
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
for cam_index in range(batch["observation.images"].shape[-4]):
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
||||
# buffer
|
||||
# For a list of images, the H and W may vary but H*W is constant.
|
||||
for img in batch["observation.images"]:
|
||||
cam_features = self.backbone(img)["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features)
|
||||
|
||||
# Rearrange features to (sequence, batch, dim).
|
||||
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
|
||||
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
|
||||
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
|
||||
# and move to (sequence, batch, dim).
|
||||
all_cam_features = torch.cat(all_cam_features, axis=-1)
|
||||
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
|
||||
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
|
||||
|
||||
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
|
||||
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
@@ -26,6 +25,7 @@ from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
@@ -55,6 +55,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "pi0fast":
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -70,13 +74,14 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
device: str | torch.device,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
@@ -88,7 +93,6 @@ def make_policy(
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||
be loaded with the weights from that path.
|
||||
device (str): the device to load the policy onto.
|
||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||
@@ -96,7 +100,7 @@ def make_policy(
|
||||
|
||||
Raises:
|
||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
|
||||
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
|
||||
|
||||
Returns:
|
||||
PreTrainedPolicy: _description_
|
||||
@@ -111,7 +115,7 @@ def make_policy(
|
||||
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
||||
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
||||
# slower than running natively on MPS.
|
||||
if cfg.type == "vqbet" and str(device) == "mps":
|
||||
if cfg.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
@@ -145,7 +149,7 @@ def make_policy(
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(**kwargs)
|
||||
|
||||
policy.to(device)
|
||||
policy.to(cfg.device)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
@@ -76,6 +90,7 @@ class PI0Config(PreTrainedConfig):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# TODO(Steven): Validate device and amp? in all policy configs?
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -31,7 +45,7 @@ def main():
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, device, ds_meta=dataset.meta)
|
||||
policy = make_policy(cfg, ds_meta=dataset.meta)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
@@ -87,7 +101,7 @@ def main():
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, device, dataset_meta)
|
||||
policy = make_policy(cfg, dataset_meta)
|
||||
|
||||
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
||||
# loss_dict["loss"].backward()
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import GemmaConfig, PaliGemmaConfig
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Convert pi0 parameters from Jax to Pytorch
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from packaging.version import Version
|
||||
|
||||
@@ -313,7 +313,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
actions_is_pad = batch.get("action_is_pad")
|
||||
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
136
lerobot/common/policies/pi0fast/configuration_pi0fast.py
Normal file
136
lerobot/common/policies/pi0fast/configuration_pi0fast.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0fast")
@dataclass
class PI0FASTConfig(PreTrainedConfig):
    """Configuration for the pi0+FAST policy, registered under the name "pi0fast".

    Besides the model and preprocessing hyper-parameters, the class provides the optimizer
    and learning-rate scheduler presets built from the `optimizer_*` / `scheduler_*` fields.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 10
    n_action_steps: int = 5

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded
    max_state_dim: int = 32  # 32
    max_action_dim: int = 32  # 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (224, 224)
    interpolate_like_pi: bool = False

    # Add empty images. Used by pi0_aloha_sim which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False

    # Tokenizer
    tokenizer_max_length: int = 48

    # Projector
    proj_width: int = 1024

    # Decoding
    max_decoding_steps: int = 256
    fast_skip_tokens: int = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens
    max_input_seq_len: int = 256  # 512

    # Utils
    use_cache: bool = True

    # Frozen parameters
    freeze_vision_encoder: bool = True
    freeze_lm_head: bool = True

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-5

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    # The default is None, so the annotation must allow it.
    checkpoint_path: str | None = None

    padding_side: str = "right"

    precision: str = "bfloat16"
    grad_clip_norm: float = 1

    # Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
    # In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
    relaxed_action_decoding: bool = True

    def __post_init__(self):
        """Input validation (not exhaustive).

        Raises:
            ValueError: If `n_action_steps` exceeds `chunk_size`, or if `n_obs_steps` is not 1.
        """
        super().__post_init__()

        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )

    def validate_features(self) -> None:
        """Register one placeholder visual input feature per configured empty camera.

        Each placeholder is stored in `self.input_features` under the key
        `observation.images.empty_camera_{i}` with a fixed shape of (3, 480, 640).
        """
        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        """Build the AdamW optimizer config from the `optimizer_*` fields and `grad_clip_norm`."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.grad_clip_norm,
        )

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Build the cosine-decay-with-warmup scheduler config; the peak LR is `optimizer_lr`."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        """Always None for this policy."""
        return None

    @property
    def action_delta_indices(self) -> list:
        """Indices 0 .. chunk_size - 1, i.e. one entry per action in a chunk."""
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """Always None for this policy."""
        return None
|
||||
973
lerobot/common/policies/pi0fast/modeling_pi0fast.py
Normal file
973
lerobot/common/policies/pi0fast/modeling_pi0fast.py
Normal file
@@ -0,0 +1,973 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
|
||||
|
||||
[Paper](https://arxiv.org/abs/2501.09747)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/pi0fast_base \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of training the pi0+FAST neural network from scratch:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=pi0fast \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of using the pi0+FAST pretrained model outside the LeRobot training framework:
|
||||
```python
|
||||
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from scipy.fft import idct
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
|
||||
from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
# Maps the precision names accepted in the config to the matching torch dtypes.
PRECISION = {dtype_name: getattr(torch, dtype_name) for dtype_name in ("float16", "float32", "bfloat16")}
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
    """Linearly map `x` from the range [min_val, max_val] onto [0, 1] (no clipping)."""
    span = max_val - min_val
    return (x - min_val) / span
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
    """Linearly map `x` from [0, 1] back onto the range [min_val, max_val] (no clipping)."""
    span = max_val - min_val
    return x * span + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
    """Return arcsin of `value` after clamping it to [-1, 1].

    The clamp keeps the input inside the domain of arcsin so the result never
    contains the invalid values an out-of-range input would produce.
    """
    clamped = torch.clamp(value, min=-1.0, max=1.0)
    return torch.arcsin(clamped)
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
    """Convert an Aloha gripper position from linear space to normalized angular space.

    Aloha transforms the gripper positions into a linear space. This reverses that
    transformation to be consistent with pi0, which is pretrained in angular space.
    """
    # Bounds come from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    linear_position = unnormalize(value, min_val=0.01844, max_val=0.05800)

    # Inverse of the angular-to-linear transformation inside the Interbotix code;
    # both constants are taken from the Interbotix code.
    arm_length = 0.036
    horn_radius = 0.022
    sine = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
    angle = safe_arcsin(sine)

    # Normalize to [0, 1].
    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    return normalize(angle, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
    """Convert a gripper position as used by pi0 into the position used by Aloha.

    Note that the units are still angular but the range is different.
    """
    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    angle = unnormalize(value, min_val=0.4, max_val=1.5)

    # Bounds come from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
    return normalize(angle, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
    """Invert `aloha_gripper_from_angular`: map an Aloha gripper value back to the pi0 range."""
    angle = unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return normalize(angle, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class PI0FASTPolicy(PreTrainedPolicy):
    """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""

    config_class = PI0FASTConfig
    name = "pi0fast"

    def __init__(
        self,
        config: PI0FASTConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """

        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
        self.model = PI0FAST(config)

        self.reset()

    def reset(self):
        """This should be called whenever the environment is reset."""
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    def get_optim_params(self) -> dict:
        """Return the parameters to optimize.

        NOTE(review): the annotation says `dict`, but the value is whatever
        `self.parameters()` returns; kept as-is to match the declared interface.
        """
        return self.parameters()

    def _pi_aloha_decode_state(self, state):
        """Return a copy of `state` converted from the Aloha runtime convention to the pi0 one."""
        # Work on a copy so the caller's tensor is not modified in place (a second call on the
        # same tensor would otherwise flip the joints back and re-apply the gripper transform).
        state = state.clone()
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            state[:, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
        return state

    def _pi_aloha_encode_actions(self, actions):
        """Return a copy of `actions` converted from the pi0 convention to the Aloha runtime one."""
        actions = actions.clone()  # do not mutate the caller's tensor
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
        return actions

    def _pi_aloha_encode_actions_inv(self, actions):
        """Return a copy of `actions` with the `_pi_aloha_encode_actions` conversion undone."""
        actions = actions.clone()  # do not mutate the caller's tensor
        # Flip the joints again.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
        return actions

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations.

        Actions are kept in a queue so that one action is returned per call; the model is only
        queried (via `self.model.generate_actions`) when the queue is empty. The policy is put
        in eval mode on every call.
        """
        self.eval()

        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])

        batch = self.normalize_inputs(batch)

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
            actions = self.model.generate_actions(batch)

            actions = actions[:, : self.config.n_action_steps]

            # Keep only the dimensions of the configured action feature.
            original_action_dim = self.config.action_feature.shape[0]
            actions = actions[:, :, :original_action_dim]

            actions = self.unnormalize_outputs({"action": actions})["action"]

            if self.config.adapt_to_pi_aloha:
                actions = self._pi_aloha_encode_actions(actions)

            # The generated actions are (batch_size, n_action_steps, action_dim), but the queue
            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run a training forward pass.

        Entries of `batch` are replaced with their converted/normalized versions.

        Returns:
            A `(loss, loss_dict)` tuple where `loss` is `loss_dict["loss"]` as produced by
            `self.model.forward`.
        """
        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
            batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)
        loss_dict = self.model.forward(batch)
        return loss_dict["loss"], loss_dict
|
||||
|
||||
|
||||
def block_causal_update_causal_mask(
    attention_mask,
    token_type_ids=None,
    past_key_values=None,
    cache_position=None,
    input_tensor=None,
    attn_implementation: str = "eager",
    dtype: torch.dtype | str = "float32",
):
    """
    Update the causal mask during training and generation. It can be customized to different attention masks.

    Args:
        attention_mask: Padding mask, or an already-built 4D mask which is returned unchanged.
        token_type_ids: Optional per-token block markers. When given (and the sequence is longer
            than one token) positions are allowed to attend to every position whose cumulative
            marker count is not larger than their own.
        past_key_values: Optional cache; static/hybrid caches define the target length.
        cache_position: Positions of the current tokens; also provides the device of the mask.
        input_tensor: Tensor whose first two dims give (batch, sequence length). Defaults to
            `attention_mask`.
        attn_implementation: With "flash_attention_2" no 4D mask is built.
        dtype: dtype of the mask, as a `torch.dtype` or its name as a string (e.g. "float32").

    Returns:
        The 4D additive mask (0 where attention is allowed, dtype minimum elsewhere); for
        "flash_attention_2" either the given `attention_mask` (if it contains a 0) or None.
    """
    if attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None

    # The default is the string "float32", but `torch.finfo` only accepts a torch.dtype,
    # so resolve dtype names before using them.
    if isinstance(dtype, str):
        dtype = getattr(torch, dtype)

    using_static_cache = isinstance(past_key_values, StaticCache)
    min_dtype = torch.finfo(dtype).min

    if input_tensor is None:
        input_tensor = attention_mask

    inputs_lead_dim, sequence_length = input_tensor.shape[:2]

    if using_static_cache or isinstance(past_key_values, HybridCache):
        target_length = past_key_values.get_max_cache_shape()
    else:
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else cache_position[0] + sequence_length + 1
        )

    # Handle precomputed attention masks
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    # Causal mask initialization
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )

    # Standard causal masking (triu ensures tokens can only attend to past)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)

    if sequence_length != 1 and token_type_ids is not None:
        # Apply block causal mask
        token_type_ids = token_type_ids.to(causal_mask.device).bool()
        cumsum = torch.cumsum(token_type_ids, dim=1)
        block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]

        # Combine causal_mask with block-wise attention mask
        causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
        causal_mask = causal_mask[:, None, :, :]
    else:
        # Apply past cache position constraint
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
            -1, 1
        )
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # Copy to contiguous memory for in-place edits
        mask_length = attention_mask.shape[-1]

        # Apply padding mask: positions that are allowed by the causal mask (0) but are
        # padding (0 in attention_mask) sum to 0 and get masked out.
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
            causal_mask.device
        )
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask
|
||||
|
||||
|
||||
def prepare_inputs_for_generation(
    # self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    pixel_values=None,
    attention_mask=None,
    token_type_ids=None,
    use_cache=True,
    num_logits_to_keep=None,
    labels=None,
    self=None,
    **kwargs,
):
    """Build the model inputs for one generation step using a block-causal attention mask.

    `self` is the model this function is bound to (it is passed by keyword, e.g. through
    `functools.partial`); its `dtype`, `config`, `language_model` and `_update_causal_mask`
    are used. Returns the dict produced by the language model's own
    `prepare_inputs_for_generation`, adjusted as described in the inline comments.
    """
    first_cache_position = cache_position[0]

    # create block causal attention
    if first_cache_position > 0 and input_ids.shape[1] > 0:
        # Past the first step: only the last token is fed, and the position ids are
        # extended by counting up from the last known position.
        input_tensor = input_ids[:, -1:]
        step_offsets = torch.ones(
            (position_ids.shape[0], input_ids.shape[1]),
            dtype=position_ids.dtype,
            device=position_ids.device,
        ).cumsum(-1)
        new_positions = step_offsets + position_ids[:, -1:]
        position_ids = torch.cat([position_ids, new_positions], dim=-1)
    else:
        input_tensor = inputs_embeds

    attention_mask = block_causal_update_causal_mask(
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        cache_position=cache_position,
        input_tensor=input_tensor,
        token_type_ids=token_type_ids,
        dtype=self.dtype,
        attn_implementation=self.config.text_config._attn_implementation,
    )

    # Overwritten -- custom `position_ids` and `pixel_values` handling
    model_inputs = self.language_model.prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cache_position=cache_position,
        use_cache=use_cache,
        num_logits_to_keep=num_logits_to_keep,
        token_type_ids=token_type_ids,
        **kwargs,
    )

    # Position_ids in Paligemma are 1-indexed
    if model_inputs.get("position_ids") is not None:
        model_inputs["position_ids"] += 1

    # If we're in cached decoding stage, pixel values should be None because input ids do not contain
    # special image token anymore. Otherwise we need pixel values to be passed to model.
    # NOTE: use_cache=False needs pixel_values always
    at_first_position = first_cache_position == 0
    if at_first_position:
        model_inputs["pixel_values"] = pixel_values

    is_training = token_type_ids is not None and labels is not None
    if at_first_position and isinstance(past_key_values, HybridCache):
        input_tensor = input_ids if inputs_embeds is None else inputs_embeds
        model_inputs["attention_mask"] = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
        )

    return model_inputs
|
||||
|
||||
|
||||
class PI0FAST(nn.Module):
|
||||
def __init__(self, config: PI0FASTConfig):
    """Create the tokenizers and the PaliGemma backbone described by `config`.

    The PaliGemma and FAST tokenizers/processors are loaded with `from_pretrained`; the
    PaliGemma model itself is instantiated from an in-code configuration (no pretrained
    weights are loaded here).
    """
    super().__init__()
    self.config = config

    # TODO: move tokenizers in Policy
    fast_tokenizer_path = "physical-intelligence/fast"
    pi0_paligemma_path = "google/paligemma-3b-pt-224"
    self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
    self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
    self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
    self.fast_skip_tokens = self.config.fast_skip_tokens
    self.max_input_seq_len = self.config.max_input_seq_len
    self.action_horizon = self.config.chunk_size
    self.action_dim = self.config.action_feature.shape[0]
    precision = config.precision
    # Unknown precision names fall back to float32.
    torch_precision = PRECISION.get(precision, torch.float32)

    # Fall back to the EOS id when the tokenizer has no usable pad id. Checking for None
    # (rather than attribute existence) makes the fallback reachable when the attribute
    # exists but is unset.
    pad_token_id = getattr(self.paligemma_tokenizer, "pad_token_id", None)
    self.pad_token_id = (
        pad_token_id if pad_token_id is not None else self.paligemma_tokenizer.eos_token_id
    )

    paligemma_config = CONFIG_MAPPING["paligemma"](
        transformers_version="4.48.1",
        _vocab_size=257152,
        bos_token_id=2,
        eos_token_id=1,
        hidden_size=2048,
        image_token_index=257152,
        model_type="paligemma",
        pad_token_id=0,
        projection_dim=2048,
        text_config={
            "hidden_activation": "gelu_pytorch_tanh",
            "hidden_size": 2048,
            "intermediate_size": 16384,
            "model_type": "gemma",
            "num_attention_heads": 8,
            "num_hidden_layers": 18,
            "num_image_tokens": 256,
            "num_key_value_heads": 1,
            "torch_dtype": precision,
            "vocab_size": 257152,
            "_attn_implementation": "eager",
        },
        vision_config={
            "hidden_size": 1152,
            "intermediate_size": 4304,
            "model_type": "siglip_vision_model",
            "num_attention_heads": 16,
            "num_hidden_layers": 27,
            "num_image_tokens": 256,
            "patch_size": 14,
            "projection_dim": 2048,
            "projector_hidden_act": "gelu_pytorch_tanh",
            "torch_dtype": precision,
            "vision_use_head": False,
        },
    )
    self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)

    # Replace the model's input preparation with the block-causal variant defined in this module.
    self.pi0_paligemma.prepare_inputs_for_generation = partial(
        prepare_inputs_for_generation, self=self.pi0_paligemma
    )
    # Cast the selected parameter groups to the configured precision.
    params_to_change_dtype = [
        "language_model",
        "vision_tower",
        "multi_modal",
    ]
    for name, param in self.pi0_paligemma.named_parameters():
        if any(selector in name for selector in params_to_change_dtype):
            param.data = param.data.to(dtype=torch_precision)
    self.set_requires_grad()
    self.image_keys = self.config.image_features.keys()
    self.ignore_index = self.pi0_paligemma.config.ignore_index
    self.padding_side = self.config.padding_side
|
||||
|
||||
def set_requires_grad(self):
    """Freeze the parts of the PaliGemma model selected by the config.

    With `freeze_vision_encoder` the vision tower is put in eval mode and its parameters
    stop requiring gradients; with `freeze_lm_head` every parameter whose name contains
    "embed_tokens" stops requiring gradients.
    """
    paligemma = self.pi0_paligemma

    if self.config.freeze_vision_encoder:
        paligemma.vision_tower.eval()
        for weight in paligemma.vision_tower.parameters():
            weight.requires_grad = False

    # To avoid unused params issue with distributed training
    if self.config.freeze_lm_head:
        for param_name, weight in paligemma.named_parameters():
            # lm heads and embedding layer are tied
            if "embed_tokens" in param_name:
                weight.requires_grad = False
|
||||
|
||||
def embed_tokens(self, tokens: torch.Tensor):
    """Embed `tokens` with the token-embedding layer of the PaliGemma language model."""
    embedding_layer = self.pi0_paligemma.language_model.model.embed_tokens
    return embedding_layer(tokens)
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
    """Forward all arguments to the PaliGemma model's `prepare_inputs_for_generation`."""
    delegate = self.pi0_paligemma.prepare_inputs_for_generation
    return delegate(*args, **kwargs)
|
||||
|
||||
def prepare_images(self, batch):
    """Preprocess the image entries of a LeRobot batch into Pi0 inputs.

    Images present in the batch are optionally resized with padding and rescaled from
    [0, 1] to [-1, 1]. For configured image keys missing from the batch, up to
    `config.empty_cameras` placeholder images filled with -1 are inserted.

    Returns:
        A `(images, img_masks)` pair of equally long lists, ordered like `self.image_keys`;
        each mask is an all-True boolean tensor with one entry per batch element.

    Raises:
        ValueError: If none of the configured image keys is present in the batch.
    """
    present_img_keys = [key for key in self.image_keys if key in batch]
    if len(present_img_keys) == 0:
        raise ValueError(
            f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
        )

    def _preprocess(img):
        """Resize (if configured) and rescale one image batch."""
        if self.config.resize_imgs_with_padding is not None:
            img = resize_with_pad(
                img,
                *self.config.resize_imgs_with_padding,
                pad_value=0,
                interpolate_like_pi=self.config.interpolate_like_pi,
            )
        # Normalize from range [0,1] to [-1,1] as expected by siglip
        return img * 2.0 - 1.0

    # Preprocess image features present in the batch
    processed = {key: _preprocess(batch[key]) for key in present_img_keys}

    images = []
    img_masks = []
    # Template for placeholder images: the most recent real image, or the first available
    # one when the missing key comes before any present key (previously this case read an
    # unassigned variable and crashed).
    template = None
    first_available = processed[present_img_keys[0]]
    num_empty_cameras = 0
    for key in self.image_keys:
        if key in processed:
            img = processed[key]
            template = img
        else:
            if num_empty_cameras >= self.config.empty_cameras:
                continue
            reference = template if template is not None else first_available
            img = torch.ones_like(reference) * -1
            num_empty_cameras += 1

        mask = torch.ones(img.shape[0], dtype=torch.bool, device=img.device)
        images.append(img)
        img_masks.append(mask)
    return images, img_masks
|
||||
|
||||
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
    """Min-max scale each sample of `actions` to [-1, 1].

    The minimum and maximum are taken per batch element over dims 1 and 2; a small
    epsilon in the denominator avoids division by zero for constant samples.
    """
    reduce_dims = (1, 2)
    lowest = torch.amin(actions, dim=reduce_dims, keepdim=True)
    highest = torch.amax(actions, dim=reduce_dims, keepdim=True)
    value_range = highest - lowest + 1e-8
    return 2 * (actions - lowest) / value_range - 1
|
||||
|
||||
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
    """Map token ids by mirroring them against the end of the PaliGemma vocabulary.

    The result is `vocab_size - 1 - fast_skip_tokens - tokens`; applying the mapping
    twice returns the original ids.
    """
    highest_id = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens
    return highest_id - tokens
|
||||
|
||||
def fast_tokenizer_wrapper(self, actions_norm):
    """
    Tokenize a batch of actions with `self.fast_tokenizer` and pad the resulting id
    sequences with the PaliGemma processor's tokenizer, returning PyTorch tensors.
    """
    token_batch = {"input_ids": self.fast_tokenizer(actions_norm)}
    return self.processor.tokenizer.pad(token_batch, return_tensors="pt")
|
||||
|
||||
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
|
||||
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
|
||||
# Compute cumulative sum mask
|
||||
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
|
||||
# Suffix block (everything after prefix_len)
|
||||
suffix_mask = cumsum_mask > prefix_len
|
||||
token_type_ids = suffix_mask
|
||||
return token_type_ids
|
||||
|
||||
def create_input_tokens(self, state, lang_text, actions=None):
    """Tokenize the task text, the discretized state and (optionally) the actions.

    The prefix is the text "Task: <task>, State: <bins>;\\n"; when `actions` is given the
    suffix is "Action: " followed by the FAST action tokens and an EOS token.

    Args:
        state: State tensor; its first dim is the batch size. Values are bucketized into
            256 bins over [-1, 1] and only the first 32 entries per sample are used.
        lang_text: One task string per batch element.
        actions: Optional action tensor to encode as the suffix (training); None for inference.

    Returns:
        The padded tokenizer output with keys `input_ids`, `padded_mask` (real vs padding
        tokens), `attention_mask` (True after the prefix), `loss_mask` (suffix and not
        padding) and `token_type_ids`.
    """
    bsize = state.shape[0]
    device = state.device
    bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
    discretized = torch.bucketize(state, bins) - 1
    discretized = discretized[:, :32]

    prefix_texts = []
    for txt, disc in zip(lang_text, discretized, strict=False):
        cleaned = txt.lower().strip().replace("_", " ")
        state_str = " ".join(str(val.item()) for val in disc)
        prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")

    prefix_out = self.paligemma_tokenizer(
        prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
    )
    prefix_ids = prefix_out["input_ids"].to(device)
    prefix_mask = prefix_out["attention_mask"].to(device)
    prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()

    if actions is not None:
        actions_norm = self.normalize_actions(actions)
        # Pad (or cut) the action dimension to exactly `max_action_dim`.
        actions_pad = F.pad(
            actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
        )[:, :, : self.config.max_action_dim]
        fast_out = self.fast_tokenizer_wrapper(
            actions_pad.cpu(),
        )
        act_ids = fast_out["input_ids"]
        act_mask = fast_out["attention_mask"].to(device)

        act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
        # Ids that were 0 before the mapping land on `vocab_size - 1 - fast_skip_tokens`;
        # replace those with the pad token.
        act_ids = torch.where(
            act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
            self.pad_token_id,
            act_ids,
        )

        eos_token = torch.tensor(
            [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
        ).expand(bsize, -1)
        eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
        bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
        bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
        bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
        act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
        act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
        act_mask = act_mask.to(device)
    else:
        # No suffix at inference time: zero-width tensors. The width must be a literal 0;
        # using the pad token id as the size only worked while that id happened to be 0.
        act_ids = torch.empty(bsize, 0, dtype=torch.long, device=device)
        act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
    final_ids = torch.cat([prefix_ids, act_ids], dim=1)

    final_mask = torch.cat([prefix_mask, act_mask], dim=1)
    batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}

    # Use tokenizer pad function
    padded_output = self.paligemma_tokenizer.pad(
        batch_inputs, padding="longest", max_length=180, return_tensors="pt"
    )
    padded_mask = padded_output["attention_mask"]

    # True for every position after the first `prefix_lens` non-padding tokens of each row
    att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens

    token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)

    padded_output["padded_mask"] = padded_output.pop("attention_mask")
    padded_output["attention_mask"] = att_mask
    # loss is computed not on prefix, and not on padding
    padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
    padded_output["token_type_ids"] = token_type_ids
    return padded_output
|
||||
|
||||
def shift_padding_side(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
ar_mask: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
loss_mask: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
token_type_ids: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> tuple[torch.Tensor]:
|
||||
if padding_side not in ["right", "left"]:
|
||||
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
|
||||
|
||||
new_tokens = torch.empty_like(tokens)
|
||||
new_ar_masks = torch.empty_like(ar_mask)
|
||||
new_padding_mask = torch.empty_like(padding_mask)
|
||||
new_loss_mask = torch.empty_like(loss_mask)
|
||||
new_targets = torch.empty_like(targets)
|
||||
new_token_type_ids = torch.empty_like(token_type_ids)
|
||||
batch_size = tokens.shape[0]
|
||||
for i in range(batch_size):
|
||||
padding_indices = torch.where(padding_mask[i] == 0)[0]
|
||||
non_padding_indices = torch.where(padding_mask[i] == 1)[0]
|
||||
if padding_side == "left":
|
||||
new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
|
||||
else:
|
||||
new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
|
||||
new_tokens[i] = tokens[i].index_select(0, new_indices)
|
||||
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
|
||||
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
|
||||
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
|
||||
new_targets[i] = targets[i].index_select(0, new_indices)
|
||||
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
|
||||
|
||||
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]):
    """Compute the next-token cross-entropy loss on the action tokens of a batch.

    Returns:
        A dict with `loss` (tensor, mean over the positions selected by the loss mask)
        and `ce_loss` (the same value as a Python float).
    """
    device = batch[OBS_ROBOT].device
    # TODO: keep like this or move to the policy .forward
    images, img_masks = self.prepare_images(batch)

    token_batch = self.create_input_tokens(
        state=batch[OBS_ROBOT],
        lang_text=batch["task"],
        actions=batch[ACTION],
    )

    embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
        images,
        img_masks,
        token_batch["input_ids"],
        token_batch["padded_mask"],
        token_batch["attention_mask"],
        token_batch["loss_mask"],
        token_batch["token_type_ids"],
        padding_side=self.padding_side,
    )
    position_ids = torch.cumsum(pad_masks, dim=1) - 1
    token_type_ids = token_type_ids.to(dtype=torch.int64)
    # No cached tokens during training, so positions start at 0.
    cache_position = torch.arange(0, embs.shape[1], device=embs.device)

    paligemma = self.pi0_paligemma
    block_causal_mask = block_causal_update_causal_mask(
        attention_mask=pad_masks,
        past_key_values=None,
        cache_position=cache_position,
        input_tensor=embs,
        token_type_ids=token_type_ids,
        dtype=paligemma.dtype,
        attn_implementation=paligemma.config.text_config._attn_implementation,
    )
    outputs = paligemma.forward(
        input_ids=None,
        token_type_ids=None,
        attention_mask=block_causal_mask,
        position_ids=position_ids,
        past_key_values=None,
        inputs_embeds=embs,
        use_cache=False,
        labels=None,
    )

    # Shift left for next-step prediction: position t predicts token t + 1.
    logits = outputs.logits[:, :-1, :]
    targets = targets[:, 1:].to(device)
    loss_mask = loss_mask[:, 1:].to(device)

    # Per-token loss, zeroed wherever the loss mask is off.
    per_token_loss = nn.CrossEntropyLoss(reduction="none")(
        logits.reshape(-1, logits.shape[-1]), targets.reshape(-1)
    )
    masked_loss = per_token_loss * loss_mask.reshape(-1)

    # Average over the selected positions; the clamp guards against an all-zero mask.
    loss = masked_loss.sum() / torch.clamp(loss_mask.sum(), min=1)

    return {"ce_loss": loss.item(), "loss": loss}
|
||||
|
||||
def decode_actions_with_fast(
|
||||
self,
|
||||
tokens: list[list[int]],
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
relaxed_decoding: bool = True,
|
||||
) -> np.array:
|
||||
"""
|
||||
Adapt original decoding in FAST to always return actions instead of zeros.
|
||||
"""
|
||||
self.time_horizon = (
|
||||
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
|
||||
)
|
||||
self.action_dim = (
|
||||
action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
|
||||
)
|
||||
|
||||
# Cache the time horizon and action dimension for the next call
|
||||
self.called_time_horizon = self.time_horizon
|
||||
self.called_action_dim = self.action_dim
|
||||
|
||||
assert self.time_horizon is not None and self.action_dim is not None, (
|
||||
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
||||
)
|
||||
|
||||
decoded_actions = []
|
||||
for token in tokens:
|
||||
try:
|
||||
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
|
||||
if relaxed_decoding:
|
||||
# Expected sequence length
|
||||
expected_seq_len = self.time_horizon * self.action_dim
|
||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||
# Apply truncation if too long
|
||||
if diff < 0:
|
||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
|
||||
# Apply padding if too short
|
||||
elif diff > 0:
|
||||
decoded_dct_coeff = np.pad(
|
||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
self.time_horizon,
|
||||
self.action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
    """
    Turn generated output tokens back into continuous actions.

    The token ids are decoded to text, the "Action:" prefix, stray colons and
    everything after the first "|" are removed, and the remaining text is
    re-encoded and mapped through `_act_tokens_to_paligemma_tokens` before being
    handed to `decode_actions_with_fast`.

    Args:
        tokens (torch.Tensor): Generated token ids, one sequence per sample.
        action_horizon (int): Number of timesteps, forwarded as `time_horizon`.
        action_dim (int): Dimensionality of each action.

    Returns:
        torch.Tensor: The per-sample decoded actions stacked along a new leading
        dimension, placed on the same device as `tokens`.
    """
    # Token ids -> text, one string per sample.
    texts = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)

    per_sample_actions = []
    for text in texts:
        # Keep only the action payload located before the first "|".
        payload = text.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
        raw_ids = self.processor.tokenizer.encode(payload, return_tensors="pt", padding=False)
        per_sample_actions.append(self._act_tokens_to_paligemma_tokens(raw_ids))

    # Decode each sample separately, dropping the leading dimension when it has size one.
    stacked_inputs = []
    for sample_ids in per_sample_actions:
        decoded = self.decode_actions_with_fast(
            sample_ids.tolist(),
            time_horizon=action_horizon,
            action_dim=action_dim,
            relaxed_decoding=self.config.relaxed_action_decoding,
        )
        stacked_inputs.append(torch.tensor(decoded, device=tokens.device).squeeze(0))

    return torch.stack(stacked_inputs, dim=0)
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]):
    """Generate action tokens for `batch` greedily and decode them into actions.

    Images and the tokenized state/task prompt are embedded with left padding,
    passed to `self.pi0_paligemma.generate` as `inputs_embeds`, and the generated
    tokens are converted to actions with `extract_actions`.
    """
    # TODO: keep like this or move to the policy .forward
    images, img_masks = self.prepare_images(batch)

    # No ground-truth actions are supplied when generating.
    tokenized = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
    embedded_inputs = self.embed_inputs(
        images,
        img_masks,
        tokenized["input_ids"],
        tokenized["padded_mask"],
        tokenized["attention_mask"],
        tokenized["loss_mask"],
        tokenized["token_type_ids"],
        padding_side="left",
    )
    # Only the embeddings, padding mask and token types are needed here.
    embs, pad_masks, _att_masks, _targets, _loss_mask, token_type_ids = embedded_inputs
    token_type_ids = token_type_ids.to(dtype=torch.int64)

    # Positions count the non-padded entries, starting at 0 for the first one.
    prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1

    generation_kwargs = {
        "input_ids": None,
        "attention_mask": pad_masks,
        "position_ids": prefix_position_ids,
        "past_key_values": None,
        "inputs_embeds": embs,
        "use_cache": self.config.use_cache,
        "max_new_tokens": self.config.max_decoding_steps,
        "do_sample": False,
        "num_beams": 1,
        "token_type_ids": token_type_ids,
    }
    output_tokens = self.pi0_paligemma.generate(**generation_kwargs)

    return self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
    """Return `self.pi0_paligemma.get_image_features(image)` for the given image batch."""
    image_features = self.pi0_paligemma.get_image_features(image)
    return image_features
|
||||
|
||||
def embed_inputs(
    self,
    images,
    img_masks,
    tokens,
    pad_mask,
    ar_mask,
    loss_mask,
    token_type_ids,
    padding_side: str = "right",
):
    """Embed images and tokens and join them into one image-first sequence.

    Every image is embedded with `self.embed_image`, the tokens with
    `self.embed_tokens`, and the image embeddings are placed in front of the
    token embeddings. Matching masks, targets and token type ids are built for
    the image positions and concatenated with the token-level ones.

    Args:
        images: Non-empty list of image tensors of identical shape (b, c, h, w).
        img_masks: List of per-image masks of shape (b,), one per entry of `images`.
        tokens: Token ids of shape (b, t).
        pad_mask: Padding mask for `tokens`.
        ar_mask: Attention mask values for `tokens`.
        loss_mask: Loss mask for `tokens`.
        token_type_ids: Token type ids for `tokens`.
        padding_side: Forwarded to `self.shift_padding_side`.

    Returns:
        Tuple `(embs, pad_masks, att_masks, targets, loss_masks, token_type_ids)`.
        In `targets`, every entry equal to `self.pad_token_id` is replaced by
        `self.ignore_index`.
    """
    # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
    device = images[0].device

    # Fold the image axis into the batch axis so all images are embedded in one call.
    all_images = torch.stack(images, dim=1).to(device)
    b, n, c, h, w = all_images.shape
    embedded = self.embed_image(all_images.reshape(b * n, c, h, w)).to(device)

    # (B*N, P, D) -> (B, N*P, D); reshape also copes with a non-contiguous result.
    num_img_emb, image_embedding_dim = embedded.shape[1:]
    embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim)

    tokens = tokens.to(device)
    tokens_embs = self.embed_tokens(tokens)

    # Every embedding position of an image inherits that image's mask value.
    img_pad_masks = (
        torch.stack(img_masks, dim=1).unsqueeze(-1).to(device).repeat(1, 1, num_img_emb).reshape(b, -1)
    )

    num_img_positions = n * num_img_emb
    img_att_masks = torch.zeros((b, num_img_positions), dtype=torch.long, device=device)
    image_loss_mask = torch.zeros((b, num_img_positions), dtype=torch.long, device=device)
    # Image positions carry the pad token as target, so they are ignored below.
    image_target_tokens = torch.ones((b, num_img_positions), dtype=torch.long, device=device) * self.pad_token_id

    embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
    pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
    att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
    loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
    targets = torch.cat([image_target_tokens, tokens], dim=1)
    # Image positions get token type 0 (the all-zero image attention mask is reused).
    token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)

    # Shift pad tokens to the left (.generate()) or right (.train())
    embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
        embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
    )

    targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
    return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
    """Resize a batch of images to fit inside (height, width), then pad to that size.

    The aspect ratio is preserved: the images are scaled by the larger of the
    width and height ratios and the remaining area is filled with `pad_value`
    on the left and top.

    Args:
        img: Image tensor of shape (b, c, h, w).
        width: Target width in pixels.
        height: Target height in pixels.
        pad_value: Value used for the padded area.
        interpolate_like_pi: When True, each image is scaled by 255, converted to
            uint8 and resized with PIL (bilinear), and the result is returned as
            float32 divided by 255. When False, `F.interpolate` with bilinear
            mode is applied to the tensor directly.

    Returns:
        The resized and padded image tensor.

    Raises:
        ValueError: If `img` is not 4-dimensional.
    """
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    cur_height, cur_width = img.shape[2:]

    ratio = max(cur_width / width, cur_height / height)
    # Truncation can yield 0 for very elongated images; keep at least one pixel
    # so the resize below does not fail on an empty target size.
    resized_height = max(1, int(cur_height / ratio))
    resized_width = max(1, int(cur_width / ratio))

    if interpolate_like_pi:
        original_device = img.device
        # Clamp before the uint8 cast so out-of-range values saturate instead of wrapping.
        quantized = (img * 255.0).clamp(0, 255).to(dtype=torch.uint8)
        # PIL works on channel-last arrays living on the CPU.
        frames = quantized.permute(0, 2, 3, 1).to(device="cpu").numpy()
        resized_frames = [
            torch.from_numpy(
                # resample=2 is PIL's bilinear filter.
                np.array(Image.fromarray(frame).resize((resized_width, resized_height), resample=2))
            )
            for frame in frames
        ]
        resized_img = torch.stack(resized_frames, dim=0).permute(0, 3, 1, 2)
        resized_img = resized_img.to(device=original_device, dtype=torch.float32) / 255.0
    else:
        resized_img = F.interpolate(
            img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
        )

    pad_height = max(0, int(height - resized_height))
    pad_width = max(0, int(width - resized_width))

    # pad on left and top of image
    return F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
@@ -73,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
map_location: str = "cpu",
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
@@ -98,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
||||
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
@@ -112,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
||||
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
policy.to(map_location)
|
||||
policy.to(config.device)
|
||||
policy.eval()
|
||||
return policy
|
||||
|
||||
|
||||
@@ -122,7 +122,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
# When the action queue is depleted, populate it again by querying the policy.
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
|
||||
|
||||
# Remove the time dimensions as it is not handled yet.
|
||||
for key in batch:
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file contains utilities for recording frames from Intel Realsense cameras.
|
||||
"""
|
||||
@@ -34,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
|
||||
connected to the computer.
|
||||
"""
|
||||
if mock:
|
||||
import tests.mock_pyrealsense2 as rs
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
@@ -86,7 +100,7 @@ def save_images_from_cameras(
|
||||
serial_numbers = [cam["serial_number"] for cam in camera_infos]
|
||||
|
||||
if mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -100,7 +114,7 @@ def save_images_from_cameras(
|
||||
camera = IntelRealSenseCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||
)
|
||||
cameras.append(camera)
|
||||
|
||||
@@ -210,9 +224,20 @@ class IntelRealSenseCamera:
|
||||
self.serial_number = self.find_serial_number_from_name(config.name)
|
||||
else:
|
||||
self.serial_number = config.serial_number
|
||||
|
||||
# Store the raw (capture) resolution from the config.
|
||||
self.capture_width = config.width
|
||||
self.capture_height = config.height
|
||||
|
||||
# If rotated by ±90, swap width and height.
|
||||
if config.rotation in [-90, 90]:
|
||||
self.width = config.height
|
||||
self.height = config.width
|
||||
else:
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.use_depth = config.use_depth
|
||||
@@ -228,11 +253,10 @@ class IntelRealSenseCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# TODO(alibets): Do we keep original width/height or do we define them after rotation?
|
||||
self.rotation = None
|
||||
if config.rotation == -90:
|
||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
@@ -263,22 +287,26 @@ class IntelRealSenseCamera:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_pyrealsense2 as rs
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
config = rs.config()
|
||||
config.enable_device(str(self.serial_number))
|
||||
|
||||
if self.fps and self.width and self.height:
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
# TODO(rcadene): can we set rgb8 directly?
|
||||
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
|
||||
config.enable_stream(
|
||||
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.color)
|
||||
|
||||
if self.use_depth:
|
||||
if self.fps and self.width and self.height:
|
||||
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
config.enable_stream(
|
||||
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.depth)
|
||||
|
||||
@@ -316,18 +344,18 @@ class IntelRealSenseCamera:
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||
)
|
||||
if self.width is not None and self.width != actual_width:
|
||||
if self.capture_width is not None and self.capture_width != actual_width:
|
||||
raise OSError(
|
||||
f"Can't set {self.width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||
f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||
)
|
||||
if self.height is not None and self.height != actual_height:
|
||||
if self.capture_height is not None and self.capture_height != actual_height:
|
||||
raise OSError(
|
||||
f"Can't set {self.height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||
f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||
)
|
||||
|
||||
self.fps = round(actual_fps)
|
||||
self.width = round(actual_width)
|
||||
self.height = round(actual_height)
|
||||
self.capture_width = round(actual_width)
|
||||
self.capture_height = round(actual_height)
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
@@ -347,7 +375,7 @@ class IntelRealSenseCamera:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -373,7 +401,7 @@ class IntelRealSenseCamera:
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||
|
||||
h, w, _ = color_image.shape
|
||||
if h != self.height or w != self.width:
|
||||
if h != self.capture_height or w != self.capture_width:
|
||||
raise OSError(
|
||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
@@ -395,7 +423,7 @@ class IntelRealSenseCamera:
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
h, w = depth_map.shape
|
||||
if h != self.height or w != self.width:
|
||||
if h != self.capture_height or w != self.capture_width:
|
||||
raise OSError(
|
||||
f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
||||
"""
|
||||
@@ -66,7 +80,7 @@ def _find_cameras(
|
||||
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
||||
) -> list[int | str]:
|
||||
if mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -130,8 +144,8 @@ def save_images_from_cameras(
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
|
||||
f"height={camera.height}, color_mode={camera.color_mode})"
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
|
||||
f"height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||
)
|
||||
cameras.append(camera)
|
||||
|
||||
@@ -230,9 +244,19 @@ class OpenCVCamera:
|
||||
else:
|
||||
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
||||
|
||||
# Store the raw (capture) resolution from the config.
|
||||
self.capture_width = config.width
|
||||
self.capture_height = config.height
|
||||
|
||||
# If rotated by ±90, swap width and height.
|
||||
if config.rotation in [-90, 90]:
|
||||
self.width = config.height
|
||||
self.height = config.width
|
||||
else:
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.mock = config.mock
|
||||
@@ -245,11 +269,10 @@ class OpenCVCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# TODO(aliberts): Do we keep original width/height or do we define them after rotation?
|
||||
self.rotation = None
|
||||
if config.rotation == -90:
|
||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
@@ -263,7 +286,7 @@ class OpenCVCamera:
|
||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -271,10 +294,20 @@ class OpenCVCamera:
|
||||
# when other threads are used to save the images.
|
||||
cv2.setNumThreads(1)
|
||||
|
||||
backend = (
|
||||
cv2.CAP_V4L2
|
||||
if platform.system() == "Linux"
|
||||
else cv2.CAP_DSHOW
|
||||
if platform.system() == "Windows"
|
||||
else cv2.CAP_AVFOUNDATION
|
||||
if platform.system() == "Darwin"
|
||||
else cv2.CAP_ANY
|
||||
)
|
||||
|
||||
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
||||
# First create a temporary camera trying to access `camera_index`,
|
||||
# and verify it is a valid camera by calling `isOpened`.
|
||||
tmp_camera = cv2.VideoCapture(camera_idx)
|
||||
tmp_camera = cv2.VideoCapture(camera_idx, backend)
|
||||
is_camera_open = tmp_camera.isOpened()
|
||||
# Release camera to make it accessible for `find_camera_indices`
|
||||
tmp_camera.release()
|
||||
@@ -297,14 +330,14 @@ class OpenCVCamera:
|
||||
# Secondly, create the camera that will be used downstream.
|
||||
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
||||
# needs to be re-created.
|
||||
self.camera = cv2.VideoCapture(camera_idx)
|
||||
self.camera = cv2.VideoCapture(camera_idx, backend)
|
||||
|
||||
if self.fps is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||
if self.width is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
|
||||
if self.height is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
|
||||
if self.capture_width is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
|
||||
if self.capture_height is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)
|
||||
|
||||
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
||||
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
||||
@@ -316,19 +349,22 @@ class OpenCVCamera:
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
||||
)
|
||||
if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
|
||||
if self.capture_width is not None and not math.isclose(
|
||||
self.capture_width, actual_width, rel_tol=1e-3
|
||||
):
|
||||
raise OSError(
|
||||
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||
f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||
)
|
||||
if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
|
||||
if self.capture_height is not None and not math.isclose(
|
||||
self.capture_height, actual_height, rel_tol=1e-3
|
||||
):
|
||||
raise OSError(
|
||||
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||
f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||
)
|
||||
|
||||
self.fps = round(actual_fps)
|
||||
self.width = round(actual_width)
|
||||
self.height = round(actual_height)
|
||||
|
||||
self.capture_width = round(actual_width)
|
||||
self.capture_height = round(actual_height)
|
||||
self.is_connected = True
|
||||
|
||||
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
|
||||
@@ -362,14 +398,14 @@ class OpenCVCamera:
|
||||
# so we convert the image color from BGR to RGB.
|
||||
if requested_color_mode == "rgb":
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
h, w, _ = color_image.shape
|
||||
if h != self.height or w != self.width:
|
||||
if h != self.capture_height or w != self.capture_width:
|
||||
raise OSError(
|
||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
@@ -31,7 +45,7 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
|
||||
|
||||
cameras[key] = IntelRealSenseCamera(cfg)
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
@@ -1,14 +1,25 @@
|
||||
import logging
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,7 +41,7 @@ class TeleoperateControlConfig(ControlConfig):
|
||||
fps: int | None = None
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
display_data: bool = False
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("record")
|
||||
@@ -43,11 +54,6 @@ class RecordControlConfig(ControlConfig):
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
||||
device: str | None = None # cuda | cpu | mps
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int | None = None
|
||||
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
||||
@@ -76,7 +82,7 @@ class RecordControlConfig(ControlConfig):
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
display_data: bool = False
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
@@ -90,27 +96,6 @@ class RecordControlConfig(ControlConfig):
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
# When no device or use_amp are given, use the one from training config.
|
||||
if self.device is None or self.use_amp is None:
|
||||
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
||||
if self.device is None:
|
||||
self.device = train_cfg.device
|
||||
if self.use_amp is None:
|
||||
self.use_amp = train_cfg.use_amp
|
||||
|
||||
# Automatically switch to available device if necessary
|
||||
if not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("replay")
|
||||
@dataclass
|
||||
@@ -131,6 +116,11 @@ class ReplayControlConfig(ControlConfig):
|
||||
@dataclass
|
||||
class RemoteRobotConfig(ControlConfig):
|
||||
log_interval: int = 100
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp)
|
||||
viewer_ip: str | None = None
|
||||
viewer_port: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
@@ -10,7 +24,7 @@ from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import cv2
|
||||
import rerun as rr
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from termcolor import colored
|
||||
@@ -18,6 +32,7 @@ from termcolor import colored
|
||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import get_features_from_robot
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||||
@@ -159,13 +174,13 @@ def warmup_record(
|
||||
events,
|
||||
enable_teleoperation,
|
||||
warmup_time_s,
|
||||
display_cameras,
|
||||
display_data,
|
||||
fps,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=warmup_time_s,
|
||||
display_cameras=display_cameras,
|
||||
display_data=display_data,
|
||||
events=events,
|
||||
fps=fps,
|
||||
teleoperate=enable_teleoperation,
|
||||
@@ -177,22 +192,18 @@ def record_episode(
|
||||
dataset,
|
||||
events,
|
||||
episode_time_s,
|
||||
display_cameras,
|
||||
display_data,
|
||||
policy,
|
||||
device,
|
||||
use_amp,
|
||||
fps,
|
||||
single_task,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=episode_time_s,
|
||||
display_cameras=display_cameras,
|
||||
display_data=display_data,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
policy=policy,
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
teleoperate=policy is None,
|
||||
single_task=single_task,
|
||||
@@ -204,12 +215,10 @@ def control_loop(
|
||||
robot,
|
||||
control_time_s=None,
|
||||
teleoperate=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
events=None,
|
||||
policy=None,
|
||||
device: torch.device | str | None = None,
|
||||
use_amp: bool | None = None,
|
||||
policy: PreTrainedPolicy = None,
|
||||
fps: int | None = None,
|
||||
single_task: str | None = None,
|
||||
):
|
||||
@@ -232,9 +241,6 @@ def control_loop(
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
if isinstance(device, str):
|
||||
device = get_safe_torch_device(device)
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -246,7 +252,9 @@ def control_loop(
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if policy is not None:
|
||||
pred_action = predict_action(observation, policy, device, use_amp)
|
||||
pred_action = predict_action(
|
||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
)
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
action = robot.send_action(pred_action)
|
||||
@@ -256,11 +264,15 @@ def control_loop(
|
||||
frame = {**observation, **action, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
|
||||
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
|
||||
for k, v in action.items():
|
||||
for i, vv in enumerate(v):
|
||||
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
||||
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
rr.log(key, rr.Image(observation[key].numpy()), static=True)
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
@@ -289,15 +301,11 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||
)
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
def stop_recording(robot, listener, display_data):
|
||||
robot.disconnect()
|
||||
|
||||
if not is_headless():
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
if display_cameras:
|
||||
cv2.destroyAllWindows()
|
||||
if not is_headless() and listener is not None:
|
||||
listener.stop()
|
||||
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
@@ -318,7 +332,7 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -342,7 +356,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -632,7 +646,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -677,7 +691,7 @@ class DynamixelMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -743,7 +757,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -779,7 +793,7 @@ class DynamixelMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
@@ -299,7 +313,7 @@ class FeetechMotorsBus:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -323,7 +337,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -650,7 +664,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -688,7 +702,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -768,7 +782,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -804,7 +818,7 @@ class FeetechMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import (
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Sequence
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Logic to calibrate a robot arm built with dynamixel motors"""
|
||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Logic to calibrate a robot arm built with feetech motors"""
|
||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import threading
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Contains logic to instantiate a robot, read information from its motors and cameras,
|
||||
and send orders to its motors.
|
||||
"""
|
||||
@@ -460,7 +474,7 @@ class ManipulatorRobot:
|
||||
# Used when record_data=True
|
||||
follower_goal_pos[name] = goal_pos
|
||||
|
||||
goal_pos = goal_pos.numpy().astype(np.int32)
|
||||
goal_pos = goal_pos.numpy().astype(np.float32)
|
||||
self.follower_arms[name].write("Goal_Position", goal_pos)
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
|
||||
|
||||
@@ -582,7 +596,7 @@ class ManipulatorRobot:
|
||||
action_sent.append(goal_pos)
|
||||
|
||||
# Send goal position to each follower
|
||||
goal_pos = goal_pos.numpy().astype(np.int32)
|
||||
goal_pos = goal_pos.numpy().astype(np.float32)
|
||||
self.follower_arms[name].write("Goal_Position", goal_pos)
|
||||
|
||||
return torch.cat(action_sent)
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
@@ -392,21 +406,19 @@ class MobileManipulator:
|
||||
for name in self.leader_arms:
|
||||
pos = self.leader_arms[name].read("Present_Position")
|
||||
pos_tensor = torch.from_numpy(pos).float()
|
||||
# Instead of pos_tensor.item(), use tolist() to convert the entire tensor to a list
|
||||
arm_positions.extend(pos_tensor.tolist())
|
||||
|
||||
# (The rest of your code for generating wheel commands remains unchanged)
|
||||
x_cmd = 0.0 # m/s forward/backward
|
||||
y_cmd = 0.0 # m/s lateral
|
||||
y_cmd = 0.0 # m/s forward/backward
|
||||
x_cmd = 0.0 # m/s lateral
|
||||
theta_cmd = 0.0 # deg/s rotation
|
||||
if self.pressed_keys["forward"]:
|
||||
x_cmd += xy_speed
|
||||
if self.pressed_keys["backward"]:
|
||||
x_cmd -= xy_speed
|
||||
if self.pressed_keys["left"]:
|
||||
y_cmd += xy_speed
|
||||
if self.pressed_keys["right"]:
|
||||
if self.pressed_keys["backward"]:
|
||||
y_cmd -= xy_speed
|
||||
if self.pressed_keys["left"]:
|
||||
x_cmd += xy_speed
|
||||
if self.pressed_keys["right"]:
|
||||
x_cmd -= xy_speed
|
||||
if self.pressed_keys["rotate_left"]:
|
||||
theta_cmd += theta_speed
|
||||
if self.pressed_keys["rotate_right"]:
|
||||
@@ -584,8 +596,8 @@ class MobileManipulator:
|
||||
# Create the body velocity vector [x, y, theta_rad].
|
||||
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
|
||||
|
||||
# Define the wheel mounting angles with a -90° offset.
|
||||
angles = np.radians(np.array([240, 120, 0]) - 90)
|
||||
# Define the wheel mounting angles (defined from y axis cw)
|
||||
angles = np.radians(np.array([300, 180, 60]))
|
||||
# Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
|
||||
# The third column (base_radius) accounts for the effect of rotation.
|
||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||
@@ -641,8 +653,8 @@ class MobileManipulator:
|
||||
# Compute each wheel’s linear speed (m/s) from its angular speed.
|
||||
wheel_linear_speeds = wheel_radps * wheel_radius
|
||||
|
||||
# Define the wheel mounting angles with a -90° offset.
|
||||
angles = np.radians(np.array([240, 120, 0]) - 90)
|
||||
# Define the wheel mounting angles (defined from y axis cw)
|
||||
angles = np.radians(np.array([300, 180, 60]))
|
||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||
|
||||
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from lerobot.common.robot_devices.robots.configs import (
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import platform
|
||||
import time
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
@@ -94,7 +94,7 @@ class MetricsTracker:
|
||||
metrics: dict[str, AverageMeter],
|
||||
initial_step: int = 0,
|
||||
):
|
||||
self.__dict__.update({k: None for k in self.__keys__})
|
||||
self.__dict__.update(dict.fromkeys(self.__keys__))
|
||||
self._batch_size = batch_size
|
||||
self._num_frames = num_frames
|
||||
self._avg_samples_per_ep = num_frames / num_episodes
|
||||
|
||||
@@ -51,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
try_device = str(try_device)
|
||||
match try_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
@@ -85,6 +87,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
||||
|
||||
|
||||
def is_torch_device_available(try_device: str) -> bool:
|
||||
try_device = str(try_device) # Ensure try_device is a string
|
||||
if try_device == "cuda":
|
||||
return torch.cuda.is_available()
|
||||
elif try_device == "mps":
|
||||
@@ -92,7 +95,7 @@ def is_torch_device_available(try_device: str) -> bool:
|
||||
elif try_device == "cpu":
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unknown device '{try_device}.")
|
||||
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
||||
|
||||
|
||||
def is_amp_available(device: str):
|
||||
|
||||
@@ -69,7 +69,13 @@ class WandBLogger:
|
||||
os.environ["WANDB_SILENT"] = "True"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
|
||||
wandb_run_id = (
|
||||
cfg.wandb.run_id
|
||||
if cfg.wandb.run_id
|
||||
else get_wandb_run_id_from_filesystem(self.log_dir)
|
||||
if cfg.resume
|
||||
else None
|
||||
)
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=self.cfg.project,
|
||||
@@ -84,6 +90,7 @@ class WandBLogger:
|
||||
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
|
||||
job_type="train_eval",
|
||||
resume="must" if cfg.resume else None,
|
||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
|
||||
Reference in New Issue
Block a user