forked from tangger/lerobot
Compare commits: user/fraca ... hf-papers
2 commits

| Author | SHA1 | Date |
|---|---|---|
|  | 57491b44ee |  |
|  | bb3d014677 |  |
@@ -360,7 +360,7 @@ with profile(
If you want, you can cite this work with:
```bibtex
@misc{cadene2024lerobot,
    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
    title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
    howpublished = "\url{https://github.com/huggingface/lerobot}",
    year = {2024}
@@ -55,7 +55,7 @@ conda install ffmpeg -c conda-forge

Install 🤗 LeRobot:
```bash
cd lerobot && pip install -e ".[feetech]"
cd lerobot && pip install ".[feetech]"
```

## Troubleshooting
@@ -141,7 +141,7 @@ python lerobot/scripts/configure_motor.py \
--ID 1
```

Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
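The step/degree relationship described in this note can be made concrete with a small sketch. The helper below is illustrative only, not part of the LeRobot motor API.

```python
# Illustrative only: converts Feetech motor steps (0-4096 per full turn) to degrees
# relative to the 2048 midpoint, optionally applying a homing offset as described above.
RESOLUTION = 4096  # steps per full turn
MIDPOINT = RESOLUTION // 2  # 2048

def steps_to_degrees(position: int, homing_offset: int = 0) -> float:
    """Map a raw motor position to degrees in [-180, 180] around the midpoint."""
    return (position + homing_offset - MIDPOINT) * 360 / RESOLUTION

print(steps_to_degrees(2048))  # 0.0 degrees (middle of the range)
print(steps_to_degrees(0))     # -180.0 degrees (full turn anticlockwise from the middle)
print(steps_to_degrees(4096))  # 180.0 degrees (full turn clockwise from the middle)
```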

Then unplug your motor and plug the second motor and set its ID to 2.
```bash

@@ -61,7 +61,7 @@ conda install ffmpeg -c conda-forge

Install 🤗 LeRobot:
```bash
cd lerobot && pip install -e ".[feetech]"
cd lerobot && pip install ".[feetech]"
```

> [!NOTE]
@@ -106,7 +106,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
class AsyncImageWriter:
    """
    This class abstract away the initialisation of processes or/and threads to
    save images on disk asynchronously, which is critical to control a robot and record data
    save images on disk asynchrounously, which is critical to control a robot and record data
    at a high frame rate.

    When `num_processes=0`, it creates a threads pool of size `num_threads`.
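A minimal sketch of the threads-only path that docstring describes (`num_processes=0`): a pool of worker threads writes frames to disk while the control loop keeps running. It uses `concurrent.futures` rather than LeRobot's actual queue/worker implementation, and the save routine is a stand-in.

```python
# Sketch only, not the real AsyncImageWriter: submit writes to a thread pool so the
# caller returns immediately, then flush everything on stop().
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import numpy as np
from PIL import Image

class MinimalAsyncImageWriter:
    def __init__(self, num_threads: int = 4):
        self.pool = ThreadPoolExecutor(max_workers=num_threads)

    def save_image(self, image: np.ndarray, fpath: Path):
        # Returns immediately; the write happens on a worker thread.
        self.pool.submit(lambda: Image.fromarray(image).save(fpath))

    def stop(self):
        # Blocks until all pending writes are flushed to disk.
        self.pool.shutdown(wait=True)

writer = MinimalAsyncImageWriter(num_threads=4)
writer.save_image(np.zeros((480, 640, 3), dtype=np.uint8), Path("frame_0000.png"))
writer.stop()
```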
@@ -944,7 +944,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
    def stop_image_writer(self) -> None:
        """
        Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
        remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized.
        remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
        """
        if self.image_writer is not None:
            self.image_writer.stop()
@@ -36,7 +36,7 @@ ALOHA_MOBILE_INFO = {
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"license": "mit",
|
||||
"url": "https://mobile-aloha.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.02117",
|
||||
"paper": "https://huggingface.co/papers/2401.02117",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{fu2024mobile,
|
||||
author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
|
||||
@@ -49,7 +49,7 @@ ALOHA_STATIC_INFO = {
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"license": "mit",
|
||||
"url": "https://tonyzhaozh.github.io/aloha/",
|
||||
"paper": "https://arxiv.org/abs/2304.13705",
|
||||
"paper": "https://huggingface.co/papers/2304.13705",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{Zhao2023LearningFB,
|
||||
title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
|
||||
@@ -57,13 +57,13 @@ ALOHA_STATIC_INFO = {
|
||||
journal={RSS},
|
||||
year={2023},
|
||||
volume={abs/2304.13705},
|
||||
url={https://arxiv.org/abs/2304.13705}
|
||||
url={https://huggingface.co/papers/2304.13705}
|
||||
}""").lstrip(),
|
||||
}
|
||||
PUSHT_INFO = {
|
||||
"license": "mit",
|
||||
"url": "https://diffusion-policy.cs.columbia.edu/",
|
||||
"paper": "https://arxiv.org/abs/2303.04137v5",
|
||||
"paper": "https://huggingface.co/papers/2303.04137v5",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{chi2024diffusionpolicy,
|
||||
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
|
||||
@@ -75,7 +75,7 @@ PUSHT_INFO = {
|
||||
XARM_INFO = {
|
||||
"license": "mit",
|
||||
"url": "https://www.nicklashansen.com/td-mpc/",
|
||||
"paper": "https://arxiv.org/abs/2203.04955",
|
||||
"paper": "https://huggingface.co/papers/2203.04955",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{Hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
@@ -244,7 +244,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
|
||||
"paper": "https://arxiv.org/abs/2109.13841",
|
||||
"paper": "https://huggingface.co/papers/2109.13841",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022bottom,
|
||||
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
|
||||
@@ -261,7 +261,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sailor/",
|
||||
"paper": "https://arxiv.org/abs/2210.11435",
|
||||
"paper": "https://huggingface.co/papers/2210.11435",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{nasiriany2022sailor,
|
||||
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
|
||||
@@ -274,7 +274,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sirius/",
|
||||
"paper": "https://arxiv.org/abs/2211.08416",
|
||||
"paper": "https://huggingface.co/papers/2211.08416",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{liu2022robot,
|
||||
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
|
||||
@@ -298,14 +298,14 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://sites.google.com/view/cablerouting/home",
|
||||
"paper": "https://arxiv.org/abs/2307.08927",
|
||||
"paper": "https://huggingface.co/papers/2307.08927",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2023multistage,
|
||||
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
|
||||
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
|
||||
journal = {arXiv pre-print},
|
||||
year = {2023},
|
||||
url = {https://arxiv.org/abs/2307.08927},
|
||||
url = {https://huggingface.co/papers/2307.08927},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_fanuc_manipulation": {
|
||||
@@ -322,7 +322,7 @@ DATASETS = {
|
||||
"berkeley_gnm_cory_hall": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/1709.10489",
|
||||
"paper": "https://huggingface.co/papers/1709.10489",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{kahn2018self,
|
||||
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
|
||||
@@ -337,7 +337,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/recon-robot",
|
||||
"paper": "https://arxiv.org/abs/2104.05859",
|
||||
"paper": "https://huggingface.co/papers/2104.05859",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2021rapid,
|
||||
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
|
||||
@@ -351,7 +351,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/SACSoN-review",
|
||||
"paper": "https://arxiv.org/abs/2306.01874",
|
||||
"paper": "https://huggingface.co/papers/2306.01874",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{hirose2023sacson,
|
||||
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
|
||||
@@ -363,7 +363,7 @@ DATASETS = {
|
||||
"berkeley_mvp": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/2203.06173",
|
||||
"paper": "https://huggingface.co/papers/2203.06173",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@InProceedings{Radosavovic2022,
|
||||
title = {Real-World Robot Learning with Masked Visual Pre-training},
|
||||
@@ -375,7 +375,7 @@ DATASETS = {
|
||||
"berkeley_rpt": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/2306.10007",
|
||||
"paper": "https://huggingface.co/papers/2306.10007",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{Radosavovic2023,
|
||||
title={Robot Learning with Sensorimotor Pre-training},
|
||||
@@ -388,7 +388,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://human-world-model.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2308.10901",
|
||||
"paper": "https://huggingface.co/papers/2308.10901",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{mendonca2023structured,
|
||||
title={Structured World Models from Human Videos},
|
||||
@@ -401,7 +401,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-fusion.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2312.04549",
|
||||
"paper": "https://huggingface.co/papers/2312.04549",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chen2023playfusion,
|
||||
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
|
||||
@@ -414,7 +414,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robo-affordances.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2304.08488",
|
||||
"paper": "https://huggingface.co/papers/2304.08488",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{bahl2023affordances,
|
||||
title={Affordances from Human Videos as a Versatile Representation for Robotics},
|
||||
@@ -433,7 +433,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://diffusion-policy.cs.columbia.edu/",
|
||||
"paper": "https://arxiv.org/abs/2303.04137v5",
|
||||
"paper": "https://huggingface.co/papers/2303.04137",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chi2023diffusionpolicy,
|
||||
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
@@ -505,7 +505,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://droid-dataset.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2403.12945",
|
||||
"paper": "https://huggingface.co/papers/2403.12945",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{khazatsky2024droid,
|
||||
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
|
||||
@@ -517,7 +517,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://functional-manipulation-benchmark.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.08553",
|
||||
"paper": "https://huggingface.co/papers/2401.08553",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2024fmb,
|
||||
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
|
||||
@@ -530,7 +530,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
|
||||
"paper": "https://arxiv.org/abs/2401.14502",
|
||||
"paper": "https://huggingface.co/papers/2401.14502",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{saxena2023multiresolution,
|
||||
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
|
||||
@@ -575,7 +575,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://jyopari.github.io/VINN/",
|
||||
"paper": "https://arxiv.org/abs/2112.01511",
|
||||
"paper": "https://huggingface.co/papers/2112.01511",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{pari2021surprising,
|
||||
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
|
||||
@@ -590,7 +590,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-to-policy.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2210.10047",
|
||||
"paper": "https://huggingface.co/papers/2210.10047",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{cui2022play,
|
||||
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
|
||||
@@ -603,7 +603,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://rot-robot.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2206.15469",
|
||||
"paper": "https://huggingface.co/papers/2206.15469",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{haldar2023watch,
|
||||
title={Watch and match: Supercharging imitation with regularized optimal transport},
|
||||
@@ -633,7 +633,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/hydra-il-2023",
|
||||
"paper": "https://arxiv.org/abs/2306.17237",
|
||||
"paper": "https://huggingface.co/papers/2306.17237",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{belkhale2023hydra,
|
||||
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
|
||||
@@ -646,21 +646,21 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/visionandtouch",
|
||||
"paper": "https://arxiv.org/abs/1810.10191",
|
||||
"paper": "https://huggingface.co/papers/1810.10191",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{lee2019icra,
|
||||
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
|
||||
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
|
||||
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
year={2019},
|
||||
url={https://arxiv.org/abs/1810.10191}
|
||||
url={https://huggingface.co/papers/1810.10191}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_robocook": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://hshi74.github.io/robocook/",
|
||||
"paper": "https://arxiv.org/abs/2306.14447",
|
||||
"paper": "https://huggingface.co/papers/2306.14447",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{shi2023robocook,
|
||||
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
|
||||
@@ -673,7 +673,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
|
||||
"paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911",
|
||||
"paper": "https://huggingface.co/papers/2209.08959, https://huggingface.co/papers/2210.01911",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{rosete2022tacorl,
|
||||
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
|
||||
@@ -693,7 +693,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "URL",
|
||||
"paper": "https://arxiv.org/abs/2107.05842",
|
||||
"paper": "https://huggingface.co/papers/2107.05842",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@Article{Osa22,
|
||||
author = {Takayuki Osa},
|
||||
@@ -709,7 +709,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://toto-benchmark.org/",
|
||||
"paper": "https://arxiv.org/abs/2306.00942",
|
||||
"paper": "https://huggingface.co/papers/2306.00942",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{zhou2023train,
|
||||
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
|
||||
@@ -733,7 +733,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://owmcorl.github.io/#",
|
||||
"paper": "https://arxiv.org/abs/2310.16029",
|
||||
"paper": "https://huggingface.co/papers/2310.16029",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@preprint{Feng2023Finetuning,
|
||||
title={Finetuning Offline World Models in the Real World},
|
||||
@@ -745,7 +745,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robopil.github.io/d3fields/",
|
||||
"paper": "https://arxiv.org/abs/2309.16118",
|
||||
"paper": "https://huggingface.co/papers/2309.16118",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{wang2023d3field,
|
||||
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
|
||||
@@ -758,7 +758,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://uscresl.github.io/dmfd/",
|
||||
"paper": "https://arxiv.org/abs/2207.10148",
|
||||
"paper": "https://huggingface.co/papers/2207.10148",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{salhotra2022dmfd,
|
||||
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
|
||||
@@ -775,7 +775,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/MUTEX/",
|
||||
"paper": "https://arxiv.org/abs/2309.14320",
|
||||
"paper": "https://huggingface.co/papers/2309.14320",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2023mutex,
|
||||
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
|
||||
@@ -811,7 +811,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://saytap.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2306.07580",
|
||||
"paper": "https://huggingface.co/papers/2306.07580",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{saytap2023,
|
||||
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
|
||||
@@ -847,7 +847,7 @@ DATASETS = {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/VIOLA/",
|
||||
"paper": "https://arxiv.org/abs/2210.11339",
|
||||
"paper": "https://huggingface.co/papers/2210.11339",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022viola,
|
||||
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
|
||||
|
||||
@@ -101,7 +101,7 @@ def decode_video_frames_torchvision(
|
||||
keyframes_only = False
|
||||
torchvision.set_video_backend(backend)
|
||||
if backend == "pyav":
|
||||
keyframes_only = True # pyav doesn't support accurate seek
|
||||
keyframes_only = True # pyav doesnt support accuracte seek
|
||||
|
||||
# set a video stream reader
|
||||
# TODO(rcadene): also load audio stream at the same time
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
"""Action Chunking Transformer Policy
|
||||
|
||||
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
|
||||
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://huggingface.co/papers/2304.13705).
|
||||
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
|
||||
"""
|
||||
|
||||
@@ -41,7 +41,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
Hardware (paper: https://huggingface.co/papers/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
"""
|
||||
|
||||
config_class = ACTConfig
|
||||
@@ -161,7 +161,7 @@ class ACTPolicy(PreTrainedPolicy):
        # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
        # each dimension independently, we sum over the latent dimension to get the total
        # KL-divergence per batch element, then take the mean over the batch.
        # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
        # (See App. B of https://huggingface.co/papers/1312.6114 for more details).
        mean_kld = (
            (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
        )
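For reference, the expression above is the standard closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior (App. B of the paper cited in the comment), with `mu_hat` as μ and `log_sigma_x2_hat` as log σ²; the sum over j is the `.sum(-1)` over the latent dimension and the trailing `.mean()` averages over the batch:

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2 I)\,\|\,\mathcal{N}(0, I)\big)
  = -\tfrac{1}{2} \sum_{j} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```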

@@ -175,7 +175,7 @@ class ACTPolicy(PreTrainedPolicy):

class ACTTemporalEnsembler:
    def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
        """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
        """Temporal ensembling as described in Algorithm 2 of https://huggingface.co/papers/2304.13705.

        The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
        They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
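The hunk is cut off here; as a side note, the weight formula quoted above works out as in the short sketch below (the coefficient and chunk size are illustrative values, not the policy defaults).

```python
# Exponential ensembling weights w_i = exp(-coeff * i), normalized to sum to 1.
# w_0 corresponds to the oldest action in the chunk, as the docstring notes.
import torch

temporal_ensemble_coeff = 0.01  # illustrative value
chunk_size = 5

i = torch.arange(chunk_size, dtype=torch.float32)
weights = torch.exp(-temporal_ensemble_coeff * i)
weights = weights / weights.sum()

print(weights)        # slightly decreasing weights for newer actions (coeff > 0)
print(weights.sum())  # tensor(1.)
```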
|
||||
|
||||
@@ -81,7 +81,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
|
||||
diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
|
||||
network. This is the output dimension of that network, i.e., the embedding dimension.
|
||||
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
||||
use_film_scale_modulation: FiLM (https://huggingface.co/papers/1709.07871) is used for the Unet conditioning.
|
||||
Bias modulation is used be default, while this parameter indicates whether to also use scale
|
||||
modulation.
|
||||
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
|
||||
|
||||
@@ -48,7 +48,7 @@ from lerobot.common.policies.utils import (
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
(paper: https://huggingface.co/papers/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
"""
|
||||
|
||||
config_class = DiffusionConfig
|
||||
@@ -370,7 +370,7 @@ class DiffusionModel(nn.Module):
class SpatialSoftmax(nn.Module):
    """
    Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
    (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
    (https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation.

    At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
    of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
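A compact sketch of the operation that docstring describes, i.e. a softmax over spatial locations followed by an expected (x, y) per channel; it omits the learnable keypoint projection of the real module.

```python
# Sketch of spatial soft-argmax: per-channel softmax over the H*W locations, then the
# expected (x, y) coordinate under that distribution ("center of mass" of activations).
import torch
import torch.nn.functional as F

def spatial_soft_argmax(features: torch.Tensor) -> torch.Tensor:
    """features: (B, C, H, W) -> keypoints: (B, C, 2) in normalized [-1, 1] coordinates."""
    b, c, h, w = features.shape
    attention = F.softmax(features.view(b, c, h * w), dim=-1)  # (B, C, H*W)
    # Normalized pixel-coordinate grid, flattened to match the attention map.
    ys, xs = torch.meshgrid(
        torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij"
    )
    grid = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=-1)  # (H*W, 2)
    return attention @ grid  # (B, C, 2): one expected keypoint per channel

keypoints = spatial_soft_argmax(torch.randn(1, 32, 12, 12))
print(keypoints.shape)  # torch.Size([1, 32, 2])
```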

@@ -728,7 +728,7 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
        self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)

        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
        # FiLM modulation (https://huggingface.co/papers/1709.07871) outputs per-channel bias and (maybe) scale.
        cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
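For reference, a tiny sketch of how a conditioning vector of size `cond_channels` is typically split into a per-channel scale and bias and applied to the conv features; it mirrors the comment above but is not the exact block implementation.

```python
# Sketch of FiLM conditioning: split the encoder output into scale and bias
# (or bias only) and apply them per channel to the (B, C, T) conv activations.
import torch

out_channels, use_film_scale_modulation = 64, True
features = torch.randn(8, out_channels, 50)  # (B, C, T) conv activations
cond = torch.randn(8, out_channels * 2 if use_film_scale_modulation else out_channels)

if use_film_scale_modulation:
    scale, bias = cond.chunk(2, dim=-1)
    features = scale.unsqueeze(-1) * features + bias.unsqueeze(-1)
else:
    features = features + cond.unsqueeze(-1)  # bias-only modulation
```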
|
||||
|
||||
|
||||
@@ -357,7 +357,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expected by siglip
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""
|
||||
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
|
||||
|
||||
[Paper](https://arxiv.org/abs/2501.09747)
|
||||
[Paper](https://huggingface.co/papers/2501.09747)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
@@ -516,7 +516,7 @@ class PI0FAST(nn.Module):
|
||||
interpolate_like_pi=self.config.interpolate_like_pi,
|
||||
)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expected by siglip
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
|
||||
@@ -17,8 +17,8 @@
|
||||
"""Implementation of Finetuning Offline World Models in the Real World.
|
||||
|
||||
The comments in this code may sometimes refer to these references:
|
||||
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
|
||||
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
|
||||
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://huggingface.co/papers/2203.04955)
|
||||
FOWM paper: Finetuning Offline World Models in the Real World (https://huggingface.co/papers/2310.16029)
|
||||
"""
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
@@ -162,7 +162,7 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
# loss: total loss of training RVQ
|
||||
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
|
||||
@@ -185,7 +185,7 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
|
||||
(https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation.
|
||||
|
||||
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
||||
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
||||
@@ -387,7 +387,7 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
|
||||
# mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
|
||||
# mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://huggingface.co/papers/2206.11251).
|
||||
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||
if len_additional_action_token > 0:
|
||||
features = torch.cat(
|
||||
@@ -824,8 +824,8 @@ class VqVae(nn.Module):
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
# this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181)
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://huggingface.co/papers/2403.03181)
|
||||
# this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://huggingface.co/papers/2403.03181)
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
with torch.no_grad():
|
||||
state_rep = self.encoder(state)
|
||||
@@ -838,7 +838,7 @@ class VqVae(nn.Module):
|
||||
return state_vq, vq_code
|
||||
|
||||
def vqvae_forward(self, state):
|
||||
# This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://arxiv.org/pdf/2403.03181).
|
||||
# This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://huggingface.co/papers/2403.03181).
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
# We start with passing action (or action chunk) at:t+n through the encoder ϕ.
|
||||
state_rep = self.encoder(state)
|
||||
|
||||
@@ -336,7 +336,7 @@ class ResidualVQ(nn.Module):
    """
    Residual VQ is composed of multiple VectorQuantize layers.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    Follows Algorithm 1. in https://huggingface.co/papers/2107.03312
    "Residual Vector Quantizer (a.k.a. multi-stage vector quantizer [36]) cascades Nq layers of VQ as follows. The unquantized input vector is
    passed through a first VQ and quantization residuals are computed. The residuals are then iteratively quantized by a sequence of additional
    Nq -1 vector quantizers, as described in Algorithm 1."
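A bare-bones sketch of the cascade described in that quote: each stage quantizes the residual left by the previous one, and the overall quantization is the sum of the per-stage codes. The nearest-neighbour lookup is a stand-in for the `VectorQuantize` layers, not their implementation.

```python
# Sketch of residual vector quantization (Algorithm 1): quantize, subtract, repeat.
import torch

def nearest_code(x: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Stand-in for a VectorQuantize layer: nearest codebook entry for each vector."""
    idx = torch.cdist(x, codebook).argmin(dim=-1)
    return codebook[idx]

num_quantizers, dim = 2, 16
codebooks = [torch.randn(32, dim) for _ in range(num_quantizers)]

x = torch.randn(4, dim)
residual = x
quantized = torch.zeros_like(x)
for codebook in codebooks:
    q = nearest_code(residual, codebook)
    quantized = quantized + q   # running sum of per-stage codes
    residual = residual - q     # what the next quantizer has to explain
# `quantized` approximates x more closely with every additional stage.
```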
|
||||
@@ -1006,7 +1006,7 @@ def gumbel_sample(
|
||||
if not straight_through or temperature <= 0.0 or not training:
|
||||
return ind, one_hot
|
||||
|
||||
# use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
|
||||
# use reinmax for better second-order accuracy - https://huggingface.co/papers/2304.08612
|
||||
# algorithm 2
|
||||
|
||||
if reinmax:
|
||||
@@ -1156,7 +1156,7 @@ def batched_embedding(indices, embeds):
|
||||
|
||||
|
||||
def orthogonal_loss_fn(t):
|
||||
# eq (2) from https://arxiv.org/abs/2112.00384
|
||||
# eq (2) from https://huggingface.co/papers/2112.00384
|
||||
h, n = t.shape[:2]
|
||||
normed_codes = F.normalize(t, p=2, dim=-1)
|
||||
cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
|
||||
|
||||
@@ -243,11 +243,6 @@ def control_loop(

    timestamp = 0
    start_episode_t = time.perf_counter()

    # Controls starts, if policy is given it needs cleaning up
    if policy is not None:
        policy.reset()

    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()
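The timing pattern above (a fixed `control_time_s` budget and a per-iteration `start_loop_t`) is typically closed by sleeping off the remainder of each frame; a hedged sketch follows, since the rest of the hunk is not visible here and this is not the verbatim LeRobot loop body.

```python
# Sketch of a fixed-rate control loop: do one step, wait out the rest of 1/fps,
# and track elapsed time against control_time_s.
import time

fps, control_time_s = 30, 5
timestamp, start_episode_t = 0, time.perf_counter()

while timestamp < control_time_s:
    start_loop_t = time.perf_counter()

    pass  # read observation, run policy / teleop, send action, record frame

    dt_s = time.perf_counter() - start_loop_t
    time.sleep(max(1 / fps - dt_s, 0))  # hold the loop at roughly `fps` Hz
    timestamp = time.perf_counter() - start_episode_t
```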
|
||||
|
||||
|
||||
@@ -1,60 +0,0 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python

// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//     http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

package async_inference;

// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
    // Robot -> Policy to share observations with a remote inference server
    // Policy -> Robot to share actions predicted for given observations
    rpc SendObservations(stream Observation) returns (Empty);
    rpc StreamActions(Empty) returns (stream Action);
    rpc SendPolicyInstructions(PolicySetup) returns (Empty);
    rpc Ready(Empty) returns (Empty);
}

enum TransferState {
    TRANSFER_UNKNOWN = 0;
    TRANSFER_BEGIN = 1;
    TRANSFER_MIDDLE = 2;
    TRANSFER_END = 3;
}

// Messages
message Observation {
    // sent by Robot, to remote Policy
    TransferState transfer_state = 1;
    bytes data = 2;
}

message Action {
    // sent by remote Policy, to Robot
    TransferState transfer_state = 1;
    bytes data = 2;
}

message PolicySetup {
    // sent by Robot to remote server, to init Policy
    TransferState transfer_state = 1;
    bytes data = 2;
}

message Empty {}
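A minimal sketch of how a robot-side client could talk to the policy server defined by this proto. It assumes the generated `async_inference_pb2` / `async_inference_pb2_grpc` modules (shown further down in this diff) are importable and that the server listens on `localhost:50051`; the address and the pickled payloads are assumptions for illustration, not defined by the proto itself.

```python
# Sketch of a robot client for the AsyncInference service above (address is assumed).
import pickle

import grpc
import async_inference_pb2
import async_inference_pb2_grpc

def run_client(observations):
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = async_inference_pb2_grpc.AsyncInferenceStub(channel)
        # Handshake: tells the server a client is connected so it can reset its state.
        stub.Ready(async_inference_pb2.Empty())
        # Stream pickled observations (SendObservations is a client-streaming RPC).
        request_iter = (
            async_inference_pb2.Observation(transfer_state=1, data=pickle.dumps(obs))
            for obs in observations
        )
        stub.SendObservations(request_iter)
        # Pull back predicted action chunks (StreamActions is a server-streaming RPC).
        for action_msg in stub.StreamActions(async_inference_pb2.Empty()):
            yield pickle.loads(action_msg.data)  # nosec: server is assumed trusted
```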
|
||||
@@ -1,48 +0,0 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: async_inference.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'async_inference.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x0bPolicySetup\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xa9\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=301
|
||||
_globals['_TRANSFERSTATE']._serialized_end=397
|
||||
_globals['_OBSERVATION']._serialized_start=42
|
||||
_globals['_OBSERVATION']._serialized_end=125
|
||||
_globals['_ACTION']._serialized_start=127
|
||||
_globals['_ACTION']._serialized_end=205
|
||||
_globals['_POLICYSETUP']._serialized_start=207
|
||||
_globals['_POLICYSETUP']._serialized_end=290
|
||||
_globals['_EMPTY']._serialized_start=292
|
||||
_globals['_EMPTY']._serialized_end=299
|
||||
_globals['_ASYNCINFERENCE']._serialized_start=400
|
||||
_globals['_ASYNCINFERENCE']._serialized_end=697
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,236 +0,0 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import async_inference_pb2 as async__inference__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.71.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class AsyncInferenceStub:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendObservations = channel.stream_unary(
|
||||
'/async_inference.AsyncInference/SendObservations',
|
||||
request_serializer=async__inference__pb2.Observation.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.StreamActions = channel.unary_stream(
|
||||
'/async_inference.AsyncInference/StreamActions',
|
||||
request_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Action.FromString,
|
||||
_registered_method=True)
|
||||
self.SendPolicyInstructions = channel.unary_unary(
|
||||
'/async_inference.AsyncInference/SendPolicyInstructions',
|
||||
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/async_inference.AsyncInference/Ready',
|
||||
request_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class AsyncInferenceServicer:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def SendObservations(self, request_iterator, context):
|
||||
"""Robot -> Policy to share observations with a remote inference server
|
||||
Policy -> Robot to share actions predicted for given observations
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def StreamActions(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendPolicyInstructions(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_AsyncInferenceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendObservations': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendObservations,
|
||||
request_deserializer=async__inference__pb2.Observation.FromString,
|
||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'StreamActions': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamActions,
|
||||
request_deserializer=async__inference__pb2.Empty.FromString,
|
||||
response_serializer=async__inference__pb2.Action.SerializeToString,
|
||||
),
|
||||
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendPolicyInstructions,
|
||||
request_deserializer=async__inference__pb2.PolicySetup.FromString,
|
||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=async__inference__pb2.Empty.FromString,
|
||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'async_inference.AsyncInference', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AsyncInference:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendObservations(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/async_inference.AsyncInference/SendObservations',
|
||||
async__inference__pb2.Observation.SerializeToString,
|
||||
async__inference__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def StreamActions(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/async_inference.AsyncInference/StreamActions',
|
||||
async__inference__pb2.Empty.SerializeToString,
|
||||
async__inference__pb2.Action.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendPolicyInstructions(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/async_inference.AsyncInference/SendPolicyInstructions',
|
||||
async__inference__pb2.PolicySetup.SerializeToString,
|
||||
async__inference__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/async_inference.AsyncInference/Ready',
|
||||
async__inference__pb2.Empty.SerializeToString,
|
||||
async__inference__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Server/Client side: Sometimes you just want the environment to wait a tiny bit"""
|
||||
|
||||
idle_wait = 0.01
|
||||
|
||||
"""Client side: The environment evolves with a time resolution equal to environment_dt"""
|
||||
environment_dt = 1 / 30
|
||||
|
||||
"""Server side: Running inference on (at most) environment_dt"""
|
||||
inference_latency = environment_dt
|
||||
|
||||
"""Supported policies"""
|
||||
supported_policies = ["act", "smolvla"]
|
||||
@@ -1,128 +0,0 @@
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def setup_logging(prefix: str, info_bracket: str):
|
||||
"""Sets up logging"""
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs("logs", exist_ok=True)
|
||||
|
||||
# Delete any existing prefix_* log files
|
||||
for old_log_file in os.listdir("logs"):
|
||||
if old_log_file.startswith(prefix) and old_log_file.endswith(".log"):
|
||||
try:
|
||||
os.remove(os.path.join("logs", old_log_file))
|
||||
print(f"Deleted old log file: {old_log_file}")
|
||||
except Exception as e:
|
||||
print(f"Failed to delete old log file {old_log_file}: {e}")
|
||||
|
||||
# Set up logging with both console and file output
|
||||
logger = logging.getLogger(prefix)
|
||||
# Prevent propagation to root logger to avoid duplicate messages
|
||||
logger.propagate = False
|
||||
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(
|
||||
logging.Formatter(
|
||||
f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# File handler - creates a new log file for each run
|
||||
file_handler = logging.handlers.RotatingFileHandler(
|
||||
f"logs/policy_server_{int(time.time())}.log",
|
||||
maxBytes=10 * 1024 * 1024, # 10MB
|
||||
backupCount=5,
|
||||
)
|
||||
file_handler.setFormatter(
|
||||
logging.Formatter(
|
||||
f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
class TimedData:
|
||||
def __init__(self, timestamp: float, data: Any, timestep: int):
|
||||
"""Initialize a TimedData object.
|
||||
|
||||
Args:
|
||||
timestamp: Unix timestamp relative to data's creation.
|
||||
data: The actual data to wrap a timestamp around.
|
||||
timestep: The timestep of the data.
|
||||
"""
|
||||
self.timestamp = timestamp
|
||||
self.data = data
|
||||
self.timestep = timestep
|
||||
|
||||
def get_data(self):
|
||||
return self.data
|
||||
|
||||
def get_timestamp(self):
|
||||
return self.timestamp
|
||||
|
||||
def get_timestep(self):
|
||||
return self.timestep
|
||||
|
||||
|
||||
class TimedAction(TimedData):
|
||||
def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
|
||||
super().__init__(timestamp=timestamp, data=action, timestep=timestep)
|
||||
|
||||
def get_action(self):
|
||||
return self.get_data()
|
||||
|
||||
|
||||
class TimedObservation(TimedData):
|
||||
def __init__(
|
||||
self,
|
||||
timestamp: float,
|
||||
observation: dict[str, torch.Tensor],
|
||||
timestep: int,
|
||||
transfer_state: int = 0,
|
||||
must_go: bool = False,
|
||||
):
|
||||
super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
|
||||
self.transfer_state = transfer_state
|
||||
self.must_go = must_go
|
||||
|
||||
def get_observation(self):
|
||||
return self.get_data()
|
||||
|
||||
|
||||
class TinyPolicyConfig:
|
||||
def __init__(
|
||||
self,
|
||||
policy_type: str = "act",
|
||||
pretrained_name_or_path: str = "fracapuano/act_so100_test",
|
||||
device: str = "cpu",
|
||||
):
|
||||
self.policy_type = policy_type
|
||||
self.pretrained_name_or_path = pretrained_name_or_path
|
||||
self.device = device
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
"""Check if two observation states are similar, under a tolerance threshold"""
|
||||
return torch.linalg.norm(obs1_state - obs2_state) < atol
|
||||
|
||||
|
||||
def observations_similar(obs1: TimedObservation, obs2: TimedObservation, atol: float = 1) -> bool:
|
||||
"""Check if two observations are similar, under a tolerance threshold"""
|
||||
obs1_state = obs1.get_observation()["observation.state"]
|
||||
obs2_state = obs2.get_observation()["observation.state"]
|
||||
|
||||
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
|
||||
@@ -1,429 +0,0 @@
|
||||
import itertools
|
||||
import pickle # nosec
|
||||
import time
|
||||
from concurrent import futures
|
||||
from queue import Queue
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
import async_inference_pb2 # type: ignore
|
||||
import async_inference_pb2_grpc # type: ignore
|
||||
import grpc
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.common.policies.factory import get_policy_class
|
||||
from lerobot.scripts.server.constants import environment_dt, idle_wait, inference_latency, supported_policies
|
||||
from lerobot.scripts.server.helpers import (
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
TinyPolicyConfig,
|
||||
observations_similar,
|
||||
setup_logging,
|
||||
)
|
||||
|
||||
|
||||
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||
prefix = "policy_server"
|
||||
info_bracket = "SERVER"
|
||||
logger = setup_logging(prefix, info_bracket)
|
||||
|
||||
def __init__(self):
|
||||
# Initialize dataset action generator (to debug this first version, will be removed in the future)
|
||||
self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
|
||||
|
||||
self._setup_server()
|
||||
|
||||
self.actions_per_chunk = 20
|
||||
self.actions_overlap = 10
|
||||
|
||||
self.running = True
|
||||
|
||||
def _setup_server(self) -> None:
|
||||
"""Flushes server state when new client connects."""
|
||||
# only running inference on the latest observation received by the server
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
self._predicted_timesteps = set()
|
||||
self._predicted_observations = Queue(maxsize=1)
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
client_id = context.peer()
|
||||
self.logger.info(f"Client {client_id} connected and ready")
|
||||
self._setup_server()
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def SendPolicyInstructions(self, request, context): # noqa: N802
|
||||
"""Receive policy instructions from the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Receiving policy instructions from {client_id}")
|
||||
|
||||
policy_specs = pickle.loads(request.data) # nosec
|
||||
assert isinstance(policy_specs, TinyPolicyConfig), (
|
||||
f"Policy specs must be a TinyPolicyConfig. Got {type(policy_specs)}"
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Policy type: {policy_specs.policy_type} | "
|
||||
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
|
||||
f"Device: {policy_specs.device}"
|
||||
)
|
||||
|
||||
assert policy_specs.policy_type in supported_policies, (
|
||||
f"Policy type {policy_specs.policy_type} not supported. Supported policies: {supported_policies}"
|
||||
)
|
||||
|
||||
self.device = policy_specs.device
|
||||
self.policy_type = policy_specs.policy_type # act, pi0, etc.
|
||||
|
||||
policy_class = get_policy_class(self.policy_type)
|
||||
|
||||
start = time.time()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
end = time.time()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def SendObservations(self, request_iterator, context): # noqa: N802
|
||||
"""Receive observations from the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Receiving observations from {client_id}")
|
||||
|
||||
for observation in request_iterator:
|
||||
receive_time = time.time()
|
||||
timed_observation = pickle.loads(observation.data) # nosec
|
||||
deserialize_time = time.time()
|
||||
|
||||
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
|
||||
|
||||
if not self._maybe_enqueue_observation(timed_observation):
|
||||
continue
|
||||
|
||||
queue_time = time.time()
|
||||
|
||||
obs_timestep = timed_observation.get_timestep()
|
||||
obs_timestamp = timed_observation.get_timestamp()
|
||||
|
||||
self.logger.info(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Client timestamp: {obs_timestamp:.6f} | "
|
||||
f"Server timestamp: {receive_time:.6f} | "
|
||||
)
|
||||
|
||||
if not hasattr(self, "previous_obs_timestamp"):
|
||||
self.previous_obs_timestamp = obs_timestamp
|
||||
|
||||
self.logger.debug(
|
||||
f"1/DeltaObsT (~frequency): {1 / (1e-6 + obs_timestamp - self.previous_obs_timestamp):.6f} Hz| "
|
||||
f"Network latency: {receive_time - obs_timestamp:.6f}s | "
|
||||
f"Deserialization time: {deserialize_time - receive_time:.6f}s | "
|
||||
f"Queue time: {queue_time - deserialize_time:.6f}s | "
|
||||
)
|
||||
|
||||
self.previous_obs_timestamp = obs_timestamp
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def StreamActions(self, request, context): # noqa: N802
|
||||
"""Stream actions to the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Client {client_id} connected for action streaming")
|
||||
|
||||
# Generate action based on the most recent observation and its timestep
|
||||
try:
|
||||
obs = self.observation_queue.get()
|
||||
self.logger.info(
|
||||
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
|
||||
)
|
||||
|
||||
if obs:
|
||||
self.last_predicted_obs = obs
|
||||
self._predicted_timesteps.add(obs.get_timestep())
|
||||
start_time = time.time()
|
||||
action_chunk = self._predict_action_chunk(obs)
|
||||
# action_chunk = self._read_action_chunk(obs)
|
||||
inference_time = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
action_bytes = pickle.dumps(action_chunk) # nosec
|
||||
serialize_time = time.time() - start_time
|
||||
|
||||
# Create and return the Action
|
||||
action = async_inference_pb2.Action(transfer_state=obs.transfer_state, data=action_bytes)
|
||||
|
||||
self.logger.info(
|
||||
f"Action chunk #{obs.get_timestep()} generated | Inference time: {inference_time:.6f}s |"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Inference time: {inference_time:.6f}s |"
|
||||
f"Serialize time: {serialize_time:.6f}s |"
|
||||
f"Total time: {inference_time + serialize_time:.6f}s"
|
||||
)
|
||||
|
||||
yield action
|
||||
else:
|
||||
self.logger.warning("No observation in queue yet!")
|
||||
time.sleep(idle_wait)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in StreamActions: {e}")
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def _enqueue_and_go(self, obs: TimedObservation):
|
||||
# If queue is full, get the old observation to make room
|
||||
if self.observation_queue.full():
|
||||
# pops from queue
|
||||
_ = self.observation_queue.get_nowait()
|
||||
self.logger.debug("Observation queue was full, removed oldest observation")
|
||||
|
||||
# Now put the new observation (never blocks as queue is non-full here)
|
||||
self.observation_queue.put(obs)
|
||||
return True
|
||||
|
||||
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
|
||||
if obs.get_timestep() in self._predicted_timesteps:
|
||||
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
|
||||
return False
|
||||
|
||||
elif observations_similar(obs, previous_obs, atol=1):
|
||||
self.logger.debug(
|
||||
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
|
||||
)
|
||||
return False
|
||||
|
||||
else:
|
||||
return True
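# Assumption: observations_similar (imported from the server helpers) compares the two observations'
# state tensors within the given absolute tolerance, so near-duplicate frames are filtered out before
# ever reaching the policy.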
|
||||
|
||||
def _maybe_enqueue_observation(self, obs: TimedObservation) -> bool:
|
||||
"""Enqueue an observation if it must go through processing, otherwise skip it.
|
||||
Observations not in queue are never run through the policy network"""
|
||||
|
||||
if obs.must_go or not hasattr(self, "last_predicted_obs"):
|
||||
self.logger.info(f"[MUST GO] Enqueued observation #{obs.get_timestep()} for direct processing!")
|
||||
return self._enqueue_and_go(obs)
|
||||
|
||||
else:
|
||||
if self._obs_sanity_checks(obs, self.last_predicted_obs):
|
||||
return self._enqueue_and_go(obs)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
|
||||
"""Turn a chunk of actions into a list of TimedAction instances,
|
||||
with the first action corresponding to t_0 and the rest corresponding to
|
||||
t_0 + i*environment_dt for i in range(len(action_chunk))
|
||||
"""
|
||||
return [
|
||||
TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
|
||||
]
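# Illustrative example (not in the original code): with environment_dt = 0.05, a 3-action chunk produced
# for an observation taken at t_0 = 10.00 with timestep i_0 = 7 becomes
#   TimedAction(10.00, a_0, 7), TimedAction(10.05, a_1, 8), TimedAction(10.10, a_2, 9)
# letting the client align incoming actions with the timesteps it has already executed.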
|
||||
|
||||
@torch.no_grad()
|
||||
def _run_act_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Run ACT-like policies"""
|
||||
start_time = time.time()
|
||||
|
||||
# prepare observation for policy forward pass
|
||||
batch = self.policy.normalize_inputs(observation)
|
||||
normalize_time = time.time()
|
||||
self.logger.debug(f"Observation normalization time: {normalize_time - start_time:.6f}s")
|
||||
|
||||
if self.policy.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.policy.config.image_features]
|
||||
prep_time = time.time()
|
||||
self.logger.debug(f"Observation image preparation time: {prep_time - normalize_time:.6f}s")
|
||||
|
||||
# the forward pass outputs up to policy.config.n_action_steps actions, which may differ from
# actions_per_chunk, so the output is truncated to the first actions_per_chunk actions
actions = self.policy.model(batch)[0][:, : self.actions_per_chunk]
|
||||
|
||||
actions = self.policy.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
end_time = time.time()
|
||||
self.logger.info(f"[ACT] Action chunk generation total time: {end_time - start_time:.6f}s")
|
||||
|
||||
return actions
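# Shape note (assumption): after slicing, `actions` is (batch, actions_per_chunk, action_dim); the batch
# dimension is squeezed later in _predict_action_chunk before the chunk is timestamped.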
|
||||
|
||||
@torch.no_grad()
|
||||
def _run_pi0_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Run PI0-like policies"""
|
||||
raise NotImplementedError("PI0 policy not implemented yet")
|
||||
|
||||
@torch.no_grad()
|
||||
def _run_smolvla_policy(
|
||||
self, observation: dict[str, torch.Tensor], noise: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""Run smolvla-like policies"""
|
||||
observation = self.policy.normalize_inputs(observation)
|
||||
|
||||
images, img_masks = self.policy.prepare_images(observation)
|
||||
state = self.policy.prepare_state(observation)
|
||||
lang_tokens, lang_masks = self.policy.prepare_language(observation)
|
||||
|
||||
actions = self.policy.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.policy.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.policy.unnormalize_outputs(
|
||||
{"action": actions, "robot_type": [self.policy.config.robot_type]}
|
||||
)["action"]
|
||||
|
||||
return actions
|
||||
|
||||
def _get_action_chunk(
|
||||
self, observation: dict[str, torch.Tensor], policy_type: str = "act"
|
||||
) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy"""
|
||||
if policy_type == "act":
|
||||
return self._run_act_policy(observation)
|
||||
elif policy_type == "smolvla":
|
||||
return self._run_smolvla_policy(observation)
|
||||
else:
|
||||
raise ValueError(f"Policy class {policy_type} not supported")
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action based on the observation"""
|
||||
"""1. Prepare observation"""
|
||||
start_time = time.time()
|
||||
|
||||
observation = {
|
||||
"robot_type": [self.policy.config.robot_type],
|
||||
}
|
||||
for k, v in observation_t.get_observation().items():
|
||||
if isinstance(v, torch.Tensor):  # non-tensor values (e.g. VLA language instructions) are handled in the else branch
|
||||
if "image" in k:
|
||||
# Add batch dimension first, then reorder to NCHW format, then normalize to [0, 1]
|
||||
observation[k] = (
|
||||
v.unsqueeze(0).permute(0, 3, 1, 2).to(self.device, non_blocking=True) / 255.0
|
||||
)
|
||||
else:
|
||||
observation[k] = v.unsqueeze(0).to(self.device, non_blocking=True)
|
||||
else:
|
||||
observation[k] = v # textual instructions are passed as a list of strings
|
||||
|
||||
prep_time = time.time()
|
||||
self.logger.debug(f"Observation preparation time: {prep_time - start_time:.6f}s")
|
||||
|
||||
"""2. Get action chunk"""
|
||||
action_tensor = self._get_action_chunk(observation, self.policy_type)
|
||||
action_tensor = action_tensor.squeeze(0)
|
||||
|
||||
# Move to CPU before serializing
|
||||
action_tensor = action_tensor.cpu()
|
||||
|
||||
post_inference_time = time.time()
|
||||
self.logger.debug(f"Post-inference processing start: {post_inference_time - prep_time:.6f}s")
|
||||
|
||||
if action_tensor.dim() == 1:
|
||||
# No chunk dimension, so repeat action to create a (dummy) chunk of actions
|
||||
action_tensor = action_tensor.repeat(self.actions_per_chunk, 1)
|
||||
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
|
||||
chunk_time = time.time()
|
||||
self.logger.debug(f"Action chunk creation time: {chunk_time - post_inference_time:.6f}s")
|
||||
time.sleep(
|
||||
max(0, inference_latency - max(0, chunk_time - start_time))
|
||||
) # sleep to control inference latency
|
||||
|
||||
return action_chunk
|
||||
|
||||
def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
|
||||
"""Stream chunks of actions from a prerecorded dataset.
|
||||
|
||||
Returns:
|
||||
Generator that yields chunks of actions from the dataset
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")
|
||||
|
||||
# 1. Select the action column only, where you will find tensors with 6 elements
|
||||
actions = dataset["action"]
|
||||
action_indices = torch.arange(len(actions))
|
||||
|
||||
# 2. Chunk the iterable of tensors into chunks with 10 elements each
|
||||
# sending only first element for debugging
|
||||
indices_chunks = action_indices.unfold(
|
||||
0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
|
||||
)
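# Illustrative example: with actions_per_chunk=10 and actions_overlap=2, unfold produces index windows
# [0..9], [8..17], [16..25], ... so consecutive chunks overlap by 2 actions.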
|
||||
|
||||
for idx_chunk in indices_chunks:
|
||||
yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
|
||||
|
||||
def _read_action_chunk(self, observation: Optional[TimedObservation] = None) -> list[TimedAction]:
|
||||
"""Dummy function for predicting action chunk given observation.
|
||||
|
||||
Instead of computing actions on-the-fly, this method streams
|
||||
actions from a prerecorded dataset.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
if not observation:
|
||||
observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
|
||||
|
||||
# Get chunk of actions from the generator
|
||||
actions_chunk = next(self.action_generator)
|
||||
|
||||
# Return a list of TimedActions, with timestamps starting from the observation timestamp
|
||||
actions_chunk = self._time_action_chunk(
|
||||
observation.get_timestamp(), actions_chunk, observation.get_timestep()
|
||||
)
|
||||
|
||||
chunk_time = time.time()
|
||||
self.logger.debug(f"Action chunk creation time: {chunk_time - start_time:.6f}s")
|
||||
|
||||
# slow action generation, emulates inference time
|
||||
time.sleep(max(0, inference_latency - max(0, chunk_time - start_time)))
|
||||
|
||||
return actions_chunk
|
||||
|
||||
def stop(self):
|
||||
"""Stop the server"""
|
||||
self.running = False
|
||||
self.logger.info("Server stopping...")
|
||||
|
||||
|
||||
def serve():
|
||||
port = 8080
|
||||
# Create the server instance first
|
||||
policy_server = PolicyServer()
|
||||
|
||||
# Setup and start gRPC server
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||
server.add_insecure_port(f"[::]:{port}")
|
||||
server.start()
|
||||
policy_server.logger.info(f"PolicyServer started on port {port}")
|
||||
|
||||
try:
|
||||
# Use the running attribute to control server lifetime
|
||||
while policy_server.running:
|
||||
time.sleep(1) # Check every second instead of sleeping indefinitely
|
||||
|
||||
except KeyboardInterrupt:
|
||||
policy_server.stop()
|
||||
policy_server.logger.info("Keyboard interrupt received")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
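# Usage sketch (illustrative, file path assumed): the server is launched standalone, e.g.
#   python lerobot/scripts/server/policy_server.py
# and then waits for a RobotClient handshake (the Ready RPC) on port 8080.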
|
||||
@@ -1,608 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from queue import Empty, Queue
|
||||
from typing import Callable, Optional
|
||||
|
||||
import async_inference_pb2 # type: ignore
|
||||
import async_inference_pb2_grpc # type: ignore
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from lerobot.scripts.server.constants import environment_dt, idle_wait
|
||||
from lerobot.scripts.server.helpers import TimedAction, TimedObservation, TinyPolicyConfig, setup_logging
|
||||
|
||||
|
||||
class RobotClient:
|
||||
prefix = "robot_client"
|
||||
info_bracket = "CLIENT"
|
||||
logger = setup_logging(prefix, info_bracket)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_address: Optional[str] = None,
|
||||
policy_type: str = "smolvla",
|
||||
pretrained_name_or_path: str = "lerobot/smolvla_base",
|
||||
policy_device: str = "cuda",
|
||||
chunk_size_threshold: float = 0.5,
|
||||
robot: str = "so100",
|
||||
):
|
||||
# Use environment variable if server_address is not provided
|
||||
if server_address is None:
|
||||
server_address = os.getenv("SERVER_ADDRESS", "localhost:8080")
|
||||
self.logger.info(f"No server address provided, using default address: {server_address}")
|
||||
|
||||
self.policy_config = TinyPolicyConfig(policy_type, pretrained_name_or_path, policy_device)
|
||||
self.channel = grpc.insecure_channel(server_address)
|
||||
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||
self.logger.info(f"Initializing client to connect to server at {server_address}")
|
||||
|
||||
self.running = False
|
||||
self.must_go = True # does the observation qualify for direct processing on the policy server?
|
||||
|
||||
self.latest_action = -1
|
||||
self.action_chunk_size = -1
|
||||
|
||||
self._chunk_size_threshold = chunk_size_threshold
|
||||
|
||||
self.action_queue = Queue()
|
||||
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
|
||||
|
||||
start_time = time.time()
|
||||
self.robot = make_robot(robot)
|
||||
self.robot.connect()
|
||||
|
||||
connect_time = time.time()
|
||||
self.logger.info(f"Robot connection time: {connect_time - start_time:.4f}s")
|
||||
|
||||
time.sleep(idle_wait) # sleep waiting for cameras to activate
|
||||
self.logger.info("Robot connected and ready")
|
||||
|
||||
def timestamps(self):
"""Get the sorted timesteps of the actions currently in the queue"""
return sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
|
||||
def start(self):
|
||||
"""Start the robot client and connect to the policy server"""
|
||||
try:
|
||||
# client-server handshake
|
||||
start_time = time.time()
|
||||
self.stub.Ready(async_inference_pb2.Empty())
|
||||
end_time = time.time()
|
||||
self.logger.info(f"Connected to policy server in {end_time - start_time:.4f}s")
|
||||
|
||||
# send policy instructions
|
||||
policy_config_bytes = pickle.dumps(self.policy_config)
|
||||
policy_setup = async_inference_pb2.PolicySetup(
|
||||
transfer_state=async_inference_pb2.TRANSFER_BEGIN, data=policy_config_bytes
|
||||
)
|
||||
|
||||
self.logger.info("Sending policy instructions to policy server")
|
||||
self.logger.info(
|
||||
f"Policy type: {self.policy_config.policy_type} | "
|
||||
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
|
||||
f"Device: {self.policy_config.device}"
|
||||
)
|
||||
|
||||
self.stub.SendPolicyInstructions(policy_setup)
|
||||
|
||||
self.running = True
|
||||
self.available_actions_size = []
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the robot client"""
|
||||
self.running = False
|
||||
|
||||
self.robot.disconnect()
|
||||
self.logger.info("Robot disconnected")
|
||||
|
||||
self.channel.close()
|
||||
self.logger.info("Client stopped, channel closed")
|
||||
|
||||
def send_observation(
|
||||
self,
|
||||
obs: TimedObservation,
|
||||
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
|
||||
) -> bool:
|
||||
"""Send observation to the policy server.
|
||||
Returns True if the observation was sent successfully, False otherwise."""
|
||||
if not self.running:
|
||||
self.logger.warning("Client not running")
|
||||
return False
|
||||
|
||||
assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
|
||||
|
||||
start_time = time.time()
|
||||
observation_bytes = pickle.dumps(obs)
|
||||
serialize_time = time.time()
|
||||
self.logger.debug(f"Observation serialization time: {serialize_time - start_time:.6f}s")
|
||||
|
||||
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
|
||||
|
||||
try:
|
||||
send_start = time.time()
|
||||
_ = self.stub.SendObservations(iter([observation]))
|
||||
send_end = time.time()
|
||||
|
||||
obs_timestep = obs.get_timestep()
|
||||
|
||||
self.logger.info(
|
||||
f"Sent observation #{obs_timestep} | "
|
||||
f"Serialize time: {serialize_time - start_time:.6f}s | "
|
||||
f"Network time: {send_end - send_start:.6f}s | "
|
||||
f"Total time: {send_end - start_time:.6f}s"
|
||||
)
|
||||
|
||||
self.last_obs_sent_time = send_end
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
|
||||
return False
|
||||
|
||||
def _validate_action(self, action: TimedAction):
"""Received actions are kept only when they were produced for the current timestep or later, never for an earlier one"""
return action.get_timestep() > self.latest_action
|
||||
|
||||
def _inspect_action_queue(self):
|
||||
queue_size = self.action_queue.qsize()
|
||||
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
|
||||
return queue_size, timestamps
|
||||
|
||||
def _update_action_queue(self, actions: list[TimedAction]):
"""Replace the action queue with the incoming actions that are still valid (i.e. not older than the latest executed action)"""
|
||||
|
||||
new_queue = Queue()
|
||||
for action in actions:
|
||||
if self._validate_action(action):
|
||||
new_queue.put(action)
|
||||
|
||||
self.action_queue = new_queue
|
||||
|
||||
def _aggregate_action_queues(
|
||||
self,
|
||||
incoming_actions: list[TimedAction],
|
||||
aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
):
|
||||
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
|
||||
# TODO(fracapuano): move outside of the function and make aggregate_fn an always required argument
|
||||
if not aggregate_fn:
|
||||
# default aggregate function: take the latest action
|
||||
def aggregate_fn(x1, x2):
|
||||
return x2
|
||||
|
||||
action_intersections: list[torch.Tensor] = []
|
||||
current_action_queue = {
|
||||
action.get_timestep(): action.get_action() for action in self.action_queue.queue
|
||||
}
|
||||
|
||||
for new_action in incoming_actions:
|
||||
if new_action.get_timestep() in current_action_queue:
|
||||
# TODO(fracapuano): There is probably a way to do this with broadcasting of the two action tensors
|
||||
action_intersections.append(
|
||||
TimedAction(
|
||||
timestamp=new_action.get_timestamp(),
|
||||
action=aggregate_fn(
|
||||
current_action_queue[new_action.get_timestep()], new_action.get_action()
|
||||
),
|
||||
timestep=new_action.get_timestep(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
action_intersections.append(new_action)
|
||||
|
||||
new_queue = Queue()
|
||||
for action in action_intersections:
|
||||
if self._validate_action(action):
|
||||
new_queue.put(action)
|
||||
|
||||
self.action_queue = new_queue
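# Usage sketch (hypothetical aggregate_fn): blend overlapping actions instead of keeping only the newest one.
# def blend(old_action, new_action):
#     return 0.5 * old_action + 0.5 * new_action
# self._aggregate_action_queues(incoming_actions, aggregate_fn=blend)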
|
||||
|
||||
def _clear_action_queue(self):
|
||||
"""Clear the existing queue"""
|
||||
while not self.action_queue.empty():
|
||||
try:
|
||||
self.action_queue.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
def _fill_action_queue(self, actions: list[TimedAction]):
|
||||
"""Fill the action queue with incoming valid actions"""
|
||||
start_time = time.time()
|
||||
valid_count = 0
|
||||
|
||||
for action in actions:
|
||||
if self._validate_action(action):
|
||||
self.action_queue.put(action)
|
||||
valid_count += 1
|
||||
|
||||
end_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Queue filled: {valid_count}/{len(actions)} valid actions added in {end_time - start_time:.6f}s"
|
||||
)
|
||||
|
||||
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
|
||||
self._clear_action_queue()
|
||||
self._fill_action_queue(actions)
|
||||
|
||||
def receive_actions(self):
|
||||
"""Receive actions from the policy server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Action receiving thread starting")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Use StreamActions to get a stream of actions from the server
|
||||
for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
|
||||
receive_time = time.time()
|
||||
|
||||
# Deserialize bytes back into list[TimedAction]
|
||||
deserialize_start = time.time()
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
deserialize_end = time.time()
|
||||
|
||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
self.logger.info(f"Current latest action: {self.latest_action}")
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
old_timesteps = [self.latest_action] # queue was empty
|
||||
|
||||
# Log incoming actions
|
||||
incoming_timesteps = [a.get_timestep() for a in timed_actions]
|
||||
|
||||
# Calculate network latency if we have matching observations
|
||||
if len(timed_actions) > 0:
|
||||
first_action_timestep = timed_actions[0].get_timestep()
|
||||
server_to_client_latency = receive_time - self.last_obs_sent_time
|
||||
|
||||
self.logger.info(
|
||||
f"Received action chunk for step #{first_action_timestep} | "
|
||||
f"Latest action: #{self.latest_action} | "
|
||||
f"Network latency (server->client): {server_to_client_latency:.6f}s | "
|
||||
f"Deserialization time: {deserialize_end - deserialize_start:.6f}s"
|
||||
)
|
||||
|
||||
# Update action queue
|
||||
start_time = time.time()
|
||||
self._update_action_queue(timed_actions)
|
||||
queue_update_time = time.time() - start_time
|
||||
|
||||
self.must_go = (
|
||||
True # after receiving actions, next empty queue triggers must-go processing!
|
||||
)
|
||||
|
||||
# Get queue state after changes
|
||||
new_size, new_timesteps = self._inspect_action_queue()
|
||||
|
||||
self.logger.info(
|
||||
f"Queue update complete ({queue_update_time:.6f}s) | "
|
||||
f"Before: {old_size} items | "
|
||||
f"After: {new_size} items | "
|
||||
)
|
||||
self.logger.info(
|
||||
f"Latest action: {self.latest_action} | "
|
||||
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
|
||||
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error receiving actions: {e}")
|
||||
# Avoid tight loop on action receiver error
|
||||
time.sleep(idle_wait)
|
||||
|
||||
def _actions_available(self):
|
||||
"""Check if there are actions available in the queue"""
|
||||
return not self.action_queue.empty()
|
||||
|
||||
def _get_next_action(self) -> Optional[TimedAction]:
|
||||
"""Get the next action from the queue"""
|
||||
try:
|
||||
action = self.action_queue.get_nowait()
|
||||
return action
|
||||
|
||||
except Empty:
|
||||
return None
|
||||
|
||||
def _perform_action(self, timed_action: TimedAction):
|
||||
self.robot.send_action(timed_action.get_action())
|
||||
self.latest_action = timed_action.get_timestep()
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={timed_action.get_timestamp()} | "
|
||||
f"Action #{timed_action.get_timestep()} performed | "
|
||||
f"Queue size: {self.action_queue.qsize()}"
|
||||
)
|
||||
|
||||
def execute_actions(self):
|
||||
"""Continuously execute actions from the queue"""
|
||||
import warnings
|
||||
|
||||
warnings.warn("This method is deprecated! Will be removed soon!", stacklevel=2)
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
time.sleep(idle_wait) # wait for observation capture to start
|
||||
|
||||
self.logger.info("Action execution thread starting")
|
||||
|
||||
while self.running:
|
||||
# constantly monitor the size of the action queue
|
||||
self.available_actions_size.append(self.action_queue.qsize())
|
||||
|
||||
if self._actions_available():
|
||||
timed_action = self._get_next_action()
|
||||
self._perform_action(timed_action)
|
||||
|
||||
time.sleep(environment_dt)
|
||||
|
||||
else:
|
||||
self.logger.debug("No action available | Sleeping")
|
||||
time.sleep(idle_wait)
|
||||
|
||||
def stream_observations(self, get_observation_fn):
|
||||
"""Continuously stream observations to the server"""
|
||||
import warnings
|
||||
|
||||
warnings.warn("This method is deprecated! Will be removed soon!", stacklevel=2)
|
||||
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Observation streaming thread starting")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Get serialized observation bytes from the function
|
||||
start_time = time.time()
|
||||
observation = get_observation_fn()
|
||||
obs_capture_time = time.time() - start_time
|
||||
|
||||
self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")
|
||||
|
||||
if not hasattr(self, "last_obs_timestamp"):
|
||||
self.last_obs_timestamp = observation.get_timestamp()
|
||||
|
||||
obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()
|
||||
self.logger.info(
|
||||
f"Ts={obs_timestamp} | "
|
||||
f"Captured observation #{obs_timestep} | "
|
||||
f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
|
||||
)
|
||||
|
||||
self.last_obs_timestamp = obs_timestamp
|
||||
|
||||
# Set appropriate transfer state
|
||||
if obs_timestep == 0:
|
||||
state = async_inference_pb2.TRANSFER_BEGIN
|
||||
else:
|
||||
state = async_inference_pb2.TRANSFER_MIDDLE
|
||||
|
||||
time.sleep(environment_dt)
|
||||
self.send_observation(observation, state)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in observation sender: {e}")
|
||||
time.sleep(idle_wait)
|
||||
|
||||
def control_loop_action(self):
|
||||
"""Reading and performing actions in local queue"""
|
||||
self.available_actions_size.append(self.action_queue.qsize())
|
||||
if self._actions_available():
|
||||
# Get action from queue
|
||||
get_start = time.time()
|
||||
timed_action = self._get_next_action()
|
||||
get_end = time.time() - get_start
|
||||
|
||||
self.logger.debug(
|
||||
f"Popping action from queue to perform took {get_end:.6f}s | "
|
||||
f"Queue size: {self.action_queue.qsize()}"
|
||||
)
|
||||
|
||||
self._perform_action(timed_action)
|
||||
|
||||
def _ready_to_send_observation(self):
|
||||
"""Flags when the client is ready to send an observation"""
|
||||
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
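# Illustrative example: with action_chunk_size=50 and chunk_size_threshold=0.5 (the `g` parameter of the
# paper), a new observation is sent once 25 or fewer actions remain queued, so server-side inference
# overlaps with execution of the actions already received.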
|
||||
|
||||
def control_loop_observation(self, get_observation_fn):
|
||||
try:
|
||||
# Get serialized observation bytes from the function
|
||||
start_time = time.time()
|
||||
observation = get_observation_fn()
|
||||
obs_capture_time = time.time() - start_time
|
||||
|
||||
# If there are no actions left in the queue, the observation must go through processing!
|
||||
observation.must_go = self.must_go and self.action_queue.empty()
|
||||
self.logger.debug(f"QUEUE SIZE: {self.action_queue.qsize()} (Must go: {observation.must_go})")
|
||||
if observation.must_go:
|
||||
# must-go flag will be set again after receiving actions
|
||||
self.must_go = False
|
||||
|
||||
if not hasattr(self, "last_obs_timestamp"):
|
||||
self.last_obs_timestamp = observation.get_timestamp()
|
||||
|
||||
obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()

self.logger.info(
f"Ts={obs_timestamp} | "
f"Captured observation #{obs_timestep} | "
f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
)
# update only after logging, so the reported frequency reflects the time since the previous observation
self.last_obs_timestamp = obs_timestamp

self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")
|
||||
|
||||
# Set appropriate transfer state
|
||||
if obs_timestep == 0:
|
||||
state = async_inference_pb2.TRANSFER_BEGIN
|
||||
else:
|
||||
state = async_inference_pb2.TRANSFER_MIDDLE
|
||||
|
||||
self.send_observation(observation, state)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in observation sender: {e}")
|
||||
|
||||
def control_loop(self, get_observation_fn):
|
||||
"""Combined function for executing actions and streaming observations"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Control loop thread starting")
|
||||
|
||||
control_loops = 0
|
||||
while self.running:
|
||||
control_loop_start = time.time()
|
||||
self.control_loop_action()
|
||||
|
||||
"""Control loop: (2) Streaming observations to the remote policy server"""
|
||||
if self._ready_to_send_observation() or control_loops == 0:
|
||||
self.control_loop_observation(get_observation_fn)
|
||||
|
||||
# Dynamically adjust sleep time to maintain the desired control frequency
|
||||
time.sleep(max(0, environment_dt - (time.time() - control_loop_start)))
|
||||
control_loops += 1
|
||||
|
||||
|
||||
def async_client(task_instruction: str, verbose: int = 0):
|
||||
client = RobotClient()
|
||||
|
||||
if client.start():
|
||||
# Function to get observations from the robot
|
||||
def get_observation():
|
||||
observation_content = None
|
||||
observation_content = client.robot.capture_observation()
|
||||
|
||||
observation_content["task"] = [task_instruction]
|
||||
|
||||
observation = TimedObservation(
|
||||
timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
|
||||
)
|
||||
|
||||
return observation
|
||||
|
||||
client.logger.info("Starting all threads...")
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions)
|
||||
action_receiver_thread.daemon = True
|
||||
|
||||
control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
|
||||
control_loop_thread.daemon = True
|
||||
|
||||
# Start all threads
|
||||
action_receiver_thread.start()
|
||||
control_loop_thread.start()
|
||||
|
||||
try:
|
||||
while client.running:
|
||||
time.sleep(idle_wait)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
client.logger.info("Client stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Robot client for executing tasks via policy server")
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Task instruction for the robot to execute (e.g., 'fold my tshirt')",
|
||||
)
|
||||
parser.add_argument("--verbose", type=int, default=0, help="Verbosity level (default: 0)")
|
||||
parser.add_argument(
|
||||
"--server-port-address",
|
||||
type=str,
|
||||
default="localhost:8080",
|
||||
help="Server & port address (default: localhost:8080, or SERVER_ADDRESS env var)",
|
||||
)
|
||||
parser.add_argument("--policy-type", type=str, default="smolvla", help="Policy type (default: smolvla)")
|
||||
parser.add_argument(
|
||||
"--pretrained-name-or-path",
|
||||
type=str,
|
||||
default="lerobot/smolvla_base",
|
||||
help="Pretrained model name or path (default: lerobot/smolvla_base)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy-device", type=str, default="cuda", help="Device for policy inference (default: cuda)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-size-threshold",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Chunk size threshold (`g` in the paper, default: 0.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot",
|
||||
type=str,
|
||||
default="so100",
|
||||
help="Robot name, as per the `make_robot` function (default: so100)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create client with parsed arguments
|
||||
client = RobotClient(
|
||||
server_address=args.server_port_address,
|
||||
policy_type=args.policy_type,
|
||||
pretrained_name_or_path=args.pretrained_name_or_path,
|
||||
policy_device=args.policy_device,
|
||||
chunk_size_threshold=args.chunk_size_threshold,
|
||||
robot=args.robot,
|
||||
)
|
||||
|
||||
if client.start():
|
||||
# Function to get observations from the robot
|
||||
def get_observation():
|
||||
observation_content = None
|
||||
observation_content = client.robot.capture_observation()
|
||||
|
||||
observation_content["task"] = [args.task]
|
||||
|
||||
observation = TimedObservation(
|
||||
timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
|
||||
)
|
||||
|
||||
return observation
|
||||
|
||||
client.logger.info("Starting all threads...")
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions)
|
||||
action_receiver_thread.daemon = True
|
||||
|
||||
control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
|
||||
control_loop_thread.daemon = True
|
||||
|
||||
# Start all threads
|
||||
action_receiver_thread.start()
|
||||
control_loop_thread.start()
|
||||
|
||||
try:
|
||||
while client.running:
|
||||
time.sleep(idle_wait)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
client.logger.info("Client stopped")
|
||||
@@ -68,8 +68,8 @@ dependencies = [
|
||||
"pyzmq>=26.2.1",
|
||||
"rerun-sdk>=0.21.0",
|
||||
"termcolor>=2.4.0",
|
||||
"torch>=2.2.1",
|
||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torch>=2.2.1,<2.7",
|
||||
"torchcodec==0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torchvision>=0.21.0",
|
||||
"wandb>=0.16.3",
|
||||
"zarr>=2.17.0",
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6b1e600768a8771c5fe650e038a1193597e3810f032041b2a0d021e4496381c1
|
||||
oid sha256:0389a716d51c1c615fb2a3bfa386d89f00b0deca08c4fa21b23e020a939d0213
|
||||
size 3686488
|
||||
|
||||
@@ -28,7 +28,7 @@ from lerobot.common.datasets.transforms import (
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
|
||||
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
|
||||
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
|
||||
|
||||
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9d4ebab73eabddc58879a4e770289d19e00a1a4cf2fa5fa33cd3a3246992bc90
|
||||
oid sha256:0dc691503e7d90b2086bb408e89a65f772ce5ee6e3562ef8c127bcb09bd90851
|
||||
size 40551392
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
|
||||
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
|
||||
size 5104
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1a7a8b1a457149109f843c32bcbb047d09de2201847b9b79f7501b447f77ecf4
|
||||
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5e6ce85296b2009e7c2060d336c0429b1c7197d9adb159e7df0ba18003067b36
|
||||
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
|
||||
size 68
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
|
||||
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
|
||||
size 33400
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
|
||||
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
|
||||
size 515400
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6
|
||||
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd
|
||||
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
|
||||
size 68
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
|
||||
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
|
||||
size 33400
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
@@ -254,14 +253,7 @@ def test_backward_compatibility_single_transforms(
|
||||
|
||||
|
||||
@require_x86_64_kernel
|
||||
@pytest.mark.skipif(
|
||||
version.parse(torch.__version__) < version.parse("2.7.0"),
|
||||
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
|
||||
)
|
||||
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
||||
# NOTE: PyTorch versions have different randomness, it might break this test.
|
||||
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
|
||||
|
||||
cfg = ImageTransformsConfig(enable=True)
|
||||
default_tf = ImageTransforms(cfg)
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ def test_diffuser_scheduler(optimizer):
|
||||
"base_lrs": [0.001],
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
"verbose": False,
|
||||
}
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
@@ -55,6 +56,7 @@ def test_vqbet_scheduler(optimizer):
|
||||
"base_lrs": [0.001],
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
"verbose": False,
|
||||
}
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
@@ -75,6 +77,7 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
|
||||
"base_lrs": [0.001],
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
"verbose": False,
|
||||
}
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from pathlib import Path
|
||||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
@@ -409,16 +408,7 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
4. Check that this test now passes.
|
||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
||||
|
||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
|
||||
"""
|
||||
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
|
||||
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
|
||||
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
|
||||
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")