Compare commits
2 Commits
user/fraca
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
| | 26d2a13218 | | |
| | 2be7f3a3ff | | |
@@ -9,4 +9,6 @@
|
||||
title: Assemble SO-101
|
||||
- local: getting_started_real_world_robot
|
||||
title: Getting Started with Real-World Robots
|
||||
- local: hilserl
|
||||
title: Getting Started with Reinforcement Learning
|
||||
title: "Tutorials"
|
||||
|
||||
512
docs/source/hilserl.mdx
Normal file
@@ -0,0 +1,512 @@
|
||||
# HilSerl Real Robot Training Workflow Guide
|
||||
|
||||
Human-in-the-Loop Sample-Efficient Reinforcement Learning (HIL-SERL) with LeRobot is a workflow for taking a policy from “zero” to real-world robot mastery in just a couple of hours.
|
||||
It combines three ingredients:
|
||||
1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
|
||||
2. **On-robot actor / learner loop with human interventions:** a distributed SAC/RLPD learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
|
||||
3. **Safety & efficiency tools:** joint/EE bounds, impedance control, crop-ROI preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
|
||||
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/hilserl-main-figure.png" alt="HIL-SERL workflow" title="HIL-SERL workflow" width="100%"></img>
|
||||
</p>
|
||||
|
||||
<p align="center"><i>HIL-SERL workflow, Luo et al. 2024</i></p>
|
||||
|
||||
This guide provides step-by-step instructions for training a policy on a real robot using LeRobot's HIL-SERL implementation.
|
||||
|
||||
|
||||
# 1. Real Robot Training Workflow
|
||||
|
||||
## 1.1 Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration of the HIL-SERL environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/common/envs/configs.py`, which is defined as:
|
||||
|
||||
```python
class HILSerlRobotEnvConfig(EnvConfig):
    robot: Optional[RobotConfig] = None  # Main robot agent (defined in `lerobot/common/robots`)
    teleop: Optional[TeleoperatorConfig] = None  # Teleoperator agent, e.g., gamepad or leader arm (defined in `lerobot/common/teleoperators`)
    wrapper: Optional[EnvTransformConfig] = None  # Environment wrapper settings; check `lerobot/scripts/rl/gym_manipulator.py`
    fps: int = 10  # Control frequency
    name: str = "real_robot"  # Environment name
    mode: str = None  # "record", "replay", or None (for training)
    repo_id: Optional[str] = None  # LeRobot dataset repository ID
    dataset_root: Optional[str] = None  # Local dataset root (optional)
    task: str = ""  # Task identifier
    num_episodes: int = 10  # Number of episodes for recording
    episode: int = 0  # Episode index for replay
    device: str = "cuda"  # Compute device
    push_to_hub: bool = True  # Whether to push the recorded datasets to Hub
    pretrained_policy_name_or_path: Optional[str] = None  # For policy loading
    reward_classifier_pretrained_path: Optional[str] = None  # For reward model
```
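As a rough illustration of how the `mode`, `repo_id`, `num_episodes` and `episode` fields interact, the sketch below uses a hypothetical stand-in dataclass, `EnvSettings`, and a hypothetical `validate` helper. Neither is part of LeRobot, and the rule that record and replay both need a `repo_id` is an assumption made for the example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EnvSettings:
    """Hypothetical stand-in mirroring a few fields of HILSerlRobotEnvConfig."""

    mode: Optional[str] = None  # "record", "replay", or None (training)
    repo_id: Optional[str] = None
    num_episodes: int = 10
    episode: int = 0
    fps: int = 10


def validate(settings: EnvSettings) -> str:
    """Describe what the settings would do, or raise on an inconsistent combination."""
    if settings.mode not in ("record", "replay", None):
        raise ValueError(f"unknown mode: {settings.mode!r}")
    # Assumption for this sketch: both dataset modes need a dataset id.
    if settings.mode in ("record", "replay") and not settings.repo_id:
        raise ValueError("record and replay modes need a repo_id")
    if settings.mode == "record":
        return f"record {settings.num_episodes} episodes to {settings.repo_id}"
    if settings.mode == "replay":
        return f"replay episode {settings.episode} of {settings.repo_id}"
    return "train"
```

Calling `validate(EnvSettings())` corresponds to the training case, where `mode` is left as `None`.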
|
||||
|
||||
|
||||
## 1.2 Finding Robot Workspace Bounds
|
||||
|
||||
Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.
|
||||
|
||||
This simplifies learning on the real robot by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration.
|
||||
|
||||
### 1.2.1 Using find_joint_limits.py
|
||||
|
||||
This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training.
|
||||
Bounding the action space reduces redundant exploration by the agent and helps keep the hardware safe.
|
||||
|
||||
```bash
python -m lerobot.scripts.find_joint_limits \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem58760431541 \
    --robot.id=black \
    --teleop.type=so100_leader \
    --teleop.port=/dev/tty.usbmodem58760431551 \
    --teleop.id=blue
```
|
||||
|
||||
### 1.2.2 Workflow
|
||||
|
||||
1. Run the script and move the robot through the space that solves the task
|
||||
2. The script records the minimum and maximum end-effector positions and joint angles and prints them to the console, for example:
|
||||
```
|
||||
Max ee position [0.24170487 0.201285 0.10273342]
|
||||
Min ee position [0.16631757 -0.08237468 0.03364977]
|
||||
Max joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
Min joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
|
||||
```
|
||||
3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field
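The bookkeeping behind step 2 can be sketched as a running per-axis minimum and maximum over the positions visited. This is only an illustration of the idea, not the code in `find_joint_limits.py`; the function name `track_bounds` is hypothetical.

```python
from typing import Iterable, List, Sequence, Tuple


def track_bounds(positions: Iterable[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Return per-axis (min, max) over a stream of positions (hypothetical helper)."""
    lo: List[float] = []
    hi: List[float] = []
    for pos in positions:
        if not lo:
            # The first sample initialises both bounds.
            lo, hi = list(pos), list(pos)
            continue
        lo = [min(a, b) for a, b in zip(lo, pos)]
        hi = [max(a, b) for a, b in zip(hi, pos)]
    if not lo:
        raise ValueError("no positions recorded")
    return lo, hi
```

Moving the arm through a larger region during the run widens the reported bounds, so it is worth covering only the space the task actually needs.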
|
||||
|
||||
### 1.2.3 Example Configuration
|
||||
|
||||
```json
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.20, 0.10],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
}
|
||||
```
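One plausible way such bounds are used is to clamp each commanded end-effector target into the box before it is sent to the robot. The guide does not show how LeRobot applies `end_effector_bounds` internally, so the helper below (`clip_to_bounds`, a hypothetical name) is a sketch of the general technique only.

```python
from typing import Dict, List, Sequence


def clip_to_bounds(target: Sequence[float], bounds: Dict[str, List[float]]) -> List[float]:
    """Clamp an xyz target into the axis-aligned box given by bounds["min"] and bounds["max"]."""
    clipped = []
    for value, lower, upper in zip(target, bounds["min"], bounds["max"]):
        clipped.append(min(max(value, lower), upper))
    return clipped


# Bounds copied from the example configuration above.
bounds = {"max": [0.24, 0.20, 0.10], "min": [0.16, -0.08, 0.03]}
safe_target = clip_to_bounds([0.30, 0.00, 0.00], bounds)
```

A target that is already inside the box is returned unchanged; only the axes that exceed a limit are moved to that limit.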
|
||||
|
||||
## 1.3 Collecting Demonstrations
|
||||
|
||||
With the bounds defined, you can safely collect demonstrations for training. Because training uses an off-policy RL algorithm, previously collected offline datasets can be reused to make learning more efficient.
|
||||
|
||||
### 1.3.1 Setting Up Record Mode
|
||||
|
||||
Create a configuration file for recording demonstrations (or edit an existing one like `env_config_so100.json`):
|
||||
|
||||
1. Set `mode` to `"record"`
|
||||
2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name")
|
||||
3. Set `num_episodes` to the number of demonstrations you want to collect
|
||||
4. Set `crop_params_dict` to `null` initially (we'll determine crops later)
|
||||
5. Configure `robot`, `cameras`, and other hardware settings
|
||||
|
||||
Example configuration section:
|
||||
```json
|
||||
"mode": "record",
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"dataset_root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes": 15,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
```
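If you keep several configuration variants, the record-mode fields can be set programmatically before writing the JSON file. The helper below is a hypothetical convenience, not a LeRobot function; it only touches the top-level fields shown in the example above and assumes the `username/task_name` form for `repo_id`.

```python
import json


def make_record_config(base: dict, repo_id: str, num_episodes: int) -> dict:
    """Return a copy of `base` switched to record mode (hypothetical helper)."""
    if repo_id.count("/") != 1:
        raise ValueError("repo_id should look like 'username/task_name'")
    if num_episodes < 1:
        raise ValueError("num_episodes must be at least 1")
    config = dict(base)  # shallow copy so the caller's dict is untouched
    config.update(mode="record", repo_id=repo_id, num_episodes=num_episodes)
    return config


base = {"mode": None, "task": "pick_and_lift", "episode": 0, "push_to_hub": True}
record = make_record_config(base, "username/pick_lift_cube", 15)
print(json.dumps(record, indent=2))
```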
|
||||
|
||||
### 1.3.2 Gamepad Controls
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/gamepad_guide.jpg?raw=true" alt="Figure shows the control mappings on a Logitech gamepad." title="Gamepad Control Mapping" width="100%"></img>
|
||||
</p>
|
||||
<p align="center"><i>Gamepad button mapping for robot control and episode management</i></p>
|
||||
|
||||
|
||||
### 1.3.3 Recording Demonstrations
|
||||
|
||||
Start the recording process:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config_so100.json
|
||||
```
|
||||
|
||||
During recording:
|
||||
1. The robot will reset to the initial position defined by `fixed_reset_position` in the configuration file
|
||||
2. Use the gamepad to control the robot; this requires `"control_mode": "gamepad"` in the configuration file
|
||||
3. Complete the task successfully
|
||||
4. The episode ends with a reward of 1 when you press the "success" button
|
||||
5. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
|
||||
6. You can rerecord an episode by pressing the "rerecord" button
|
||||
7. The process automatically continues to the next episode
|
||||
8. After recording all episodes, the dataset is pushed to the Hugging Face Hub (optional) and saved locally
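The episode-ending rules in steps 4 and 5 amount to a small decision function. The sketch below uses hypothetical event names (`"success"`, `"fail"`) and a hypothetical function name; it does not cover the rerecord button, which discards the episode rather than labelling it.

```python
from typing import Optional


def episode_reward(event: Optional[str], elapsed_s: float, time_limit_s: float) -> Optional[int]:
    """Return the final reward when the episode ends, or None while it is still running."""
    if event == "success":
        return 1  # operator pressed the success button
    if event == "fail" or elapsed_s >= time_limit_s:
        return 0  # fail button pressed or time limit reached
    return None  # episode continues
```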
|
||||
|
||||
|
||||
|
||||
## 1.4 Processing the Dataset
|
||||
|
||||
After collecting demonstrations, process them to determine optimal camera crops.
|
||||
Reinforcement learning is sensitive to background distractions, so it is important to crop the images to the relevant workspace area.
|
||||
Note: If you already know the crop parameters, you can skip this step and just set the `crop_params_dict` in the configuration file during recording.
|
||||
|
||||
### 1.4.1 Determining Crop Parameters
|
||||
|
||||
Use the `crop_dataset_roi.py` script to interactively select regions of interest in your camera images:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/crop_dataset_roi.py --repo-id username/pick_lift_cube
|
||||
```
|
||||
|
||||
1. For each camera view, the script will display the first frame
|
||||
2. Draw a rectangle around the relevant workspace area
|
||||
3. Press 'c' to confirm the selection
|
||||
4. Repeat for all camera views
|
||||
5. The script outputs cropping parameters and creates a new cropped dataset
|
||||
|
||||
Example output:
|
||||
```
|
||||
Selected Rectangular Regions of Interest (top, left, height, width):
|
||||
observation.images.side: [180, 207, 180, 200]
|
||||
observation.images.front: [180, 250, 120, 150]
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/crop_dataset.gif" width="600"/>
|
||||
</p>
|
||||
|
||||
<p align="center"><i>Interactive cropping tool for selecting regions of interest</i></p>
|
||||
|
||||
|
||||
### 1.4.2 Updating Configuration
|
||||
|
||||
Add these crop parameters to your training configuration:
|
||||
|
||||
```json
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
```
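To make the `(top, left, height, width)` convention concrete, the sketch below crops a small image stored as a list of rows. It is plain Python for illustration only; the actual pipeline operates on image tensors and also resizes the result to `resize_size`, which is not shown here. The function name `crop` is hypothetical.

```python
from typing import List, Sequence


def crop(image: List[List[int]], params: Sequence[int]) -> List[List[int]]:
    """Crop a row-major image using (top, left, height, width)."""
    top, left, height, width = params
    rows, cols = len(image), len(image[0])
    if min(top, left) < 0 or height <= 0 or width <= 0:
        raise ValueError("crop parameters must be non-negative with positive size")
    if top + height > rows or left + width > cols:
        raise ValueError("crop region falls outside the image")
    # Select the rows first, then the columns inside each row.
    return [row[left:left + width] for row in image[top:top + height]]


# A 4x5 test image whose pixel value encodes its position as row * 10 + column.
image = [[r * 10 + c for c in range(5)] for r in range(4)]
patch = crop(image, (1, 2, 2, 3))
```

With the parameters from the configuration above, `[180, 207, 180, 200]` would select rows 180 to 359 and columns 207 to 406 of the side camera image.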
|
||||
|
||||
## 1.5 Training with Actor-Learner
|
||||
|
||||
The LeRobot system uses a distributed actor-learner architecture for training. You will need to start two processes: a learner and an actor.
|
||||
|
||||
### 1.5.1 Configuration Setup
|
||||
|
||||
Create a training configuration file (See example `train_config_hilserl_so100.json`). The training config is based on the main `TrainPipelineConfig` class in `lerobot/configs/train.py`.
|
||||
|
||||
1. Set `mode` to `null` (for training mode)
|
||||
2. Configure the policy settings (`type`, `device`, etc.)
|
||||
3. Set `dataset` to your cropped dataset
|
||||
4. Configure environment settings with crop parameters
|
||||
5. Check the other parameters related to SAC.
|
||||
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
|
||||
### 1.5.2 Starting the Learner
|
||||
|
||||
First, start the learner server process:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
The learner:
|
||||
- Initializes the policy network
|
||||
- Prepares replay buffers
|
||||
- Opens a gRPC server to communicate with actors
|
||||
- Processes transitions and updates the policy
|
||||
|
||||
### 1.5.3 Starting the Actor
|
||||
|
||||
In a separate terminal, start the actor process with the same configuration:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
The actor:
|
||||
- Connects to the learner via gRPC
|
||||
- Initializes the environment
|
||||
- Executes rollouts of the policy to collect experience
|
||||
- Sends transitions to the learner
|
||||
- Receives updated policy parameters
|
||||
|
||||
### 1.5.4 Training Flow
|
||||
|
||||
The training proceeds automatically:
|
||||
|
||||
1. The actor executes the policy in the environment
|
||||
2. Transitions are collected and sent to the learner
|
||||
3. The learner updates the policy based on these transitions
|
||||
4. Updated policy parameters are sent back to the actor
|
||||
5. The process continues until the specified step limit is reached
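The five steps above can be sketched as a single-process toy loop in which two queues stand in for the gRPC channels. This is an assumption-laden simplification: the real actor and learner run as separate processes, and the names `run_training` and `push_every` are hypothetical. "Policy parameters" are reduced to a version number so the data flow is easy to follow.

```python
import queue
from typing import List, Tuple


def run_training(num_steps: int, push_every: int = 2) -> Tuple[List[dict], int]:
    """Simulate the actor-learner data flow; returns (replay buffer, actor's policy version)."""
    transitions: "queue.Queue[dict]" = queue.Queue()  # actor -> learner
    parameters: "queue.Queue[int]" = queue.Queue()    # learner -> actor
    replay_buffer: List[dict] = []
    actor_version = 0    # version of the policy copy held by the actor
    learner_version = 0  # number of updates performed by the learner

    for step in range(num_steps):
        # Steps 1-2: the actor acts with its current policy copy and sends the transition.
        transitions.put({"step": step, "policy_version": actor_version})

        # Step 3: the learner stores incoming transitions and performs one update.
        while not transitions.empty():
            replay_buffer.append(transitions.get())
        learner_version += 1

        # Step 4: the learner periodically publishes its parameters.
        if learner_version % push_every == 0:
            parameters.put(learner_version)

        # The actor switches to the newest parameters it has received.
        while not parameters.empty():
            actor_version = parameters.get()

    # Step 5: stop once the step limit is reached.
    return replay_buffer, actor_version
```

The `policy_version` stored with each transition shows that the actor usually acts with slightly stale parameters, which is why an off-policy algorithm is a natural fit for this architecture.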
|
||||
|
||||
### 1.5.5 Human in the Loop
|
||||
|
||||
- The key to learning efficiently is human intervention: the operator provides corrective feedback and completes the task when needed, which aids the policy's learning and exploration.
|
||||
- To perform human interventions, you can press the upper right trigger button on the gamepad. This will pause the policy actions and allow you to take over.
|
||||
- A successful experiment is one where the human has to intervene at the start but intervenes less and less as the policy improves. You can monitor the intervention rate in the `wandb` dashboard.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/hil_effect.png?raw=true" alt="Plot of episodic reward over interaction steps, with and without human interventions." title="Effect of human interventions on policy learning" width="100%"></img>
|
||||
</p>
|
||||
|
||||
<p align="center"><i>Example showing how human interventions help guide policy learning over time</i></p>
|
||||
|
||||
- The figure plots the episodic reward against interaction steps and shows the effect of human interventions on policy learning.
- The orange curve is an experiment without any human interventions, while the pink and blue curves are experiments with human interventions.
- The number of steps needed before the policy starts achieving the maximum reward is cut by a quarter when human interventions are present.
|
||||
|
||||
#### Guide to Human Interventions
|
||||
The strategy to follow is to intervene heavily at the start of training and then reduce the amount of interventions as the training progresses. Some tips and hints:
|
||||
- Intervene for almost the entire length of the episode during the first few episodes.
|
||||
- When the policy is less chaotic, gradually reduce the intervention time during one episode and let the policy explore for a longer time.
|
||||
- Once the policy starts guiding the robot towards achieving the task, even if it's not perfect, you can limit your interventions to simple quick actions like a grasping command, or a grasp and lift command.
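The intervention rate mentioned above can be thought of as the fraction of steps in an episode during which the human was in control. How LeRobot computes the value it logs is not shown in this guide, so the definition and the name `intervention_rate` below are assumptions for illustration.

```python
from typing import Iterable


def intervention_rate(flags: Iterable[bool]) -> float:
    """Fraction of steps in one episode where the human was in control."""
    flags = list(flags)
    if not flags:
        return 0.0
    return sum(1 for flag in flags if flag) / len(flags)


# Early episode: the human controls almost every step.
early = intervention_rate([True, True, True, False])
# Later episode: a single quick correction.
late = intervention_rate([False, False, True, False])
```

Under this definition a healthy run shows the per-episode value falling over time, as in the strategy described above.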
|
||||
|
||||
## 1.6 Monitoring and Debugging
|
||||
|
||||
If you have `wandb.enable` set to `true` in your configuration, you can monitor training progress in real-time through the [Weights & Biases](https://wandb.ai/site/) dashboard.
|
||||
|
||||
# 2. Training a Reward Classifier with LeRobot
|
||||
|
||||
This guide explains how to train a reward classifier for LeRobot's human-in-the-loop reinforcement learning implementation. A reward classifier learns to predict the reward value for a given state, which can then be used in an RL setup to train a policy.
|
||||
|
||||
|
||||
The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
|
||||
## 2.1 Collecting a Dataset
|
||||
Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
|
||||
To collect a dataset, you need to modify some parameters in the environment configuration based on `HILSerlRobotEnvConfig`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
### 2.1.1 Key Parameters for Data Collection:
|
||||
|
||||
- **mode**: set it to "record" to collect a dataset
|
||||
- **repo_id**: "hf_username/dataset_name", name of the dataset and repo on the hub
|
||||
- **num_episodes**: Number of episodes to record
|
||||
- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected
|
||||
- **fps**: Number of frames per second to record
|
||||
- **push_to_hub**: Whether to push the dataset to the hub
|
||||
|
||||
The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier.
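The effect on the label distribution can be sketched as follows. The helper `label_rewards` is hypothetical and assumes that success is detected at frame index `success_step` and that recording stops after the extra frames. At `fps` 10, a value of 15 extra steps adds 1.5 seconds of positive examples per successful episode.

```python
from typing import List


def label_rewards(success_step: int, extra_steps: int) -> List[int]:
    """Per-frame reward labels for one episode (hypothetical helper).

    Frames before `success_step` get 0; the success frame and the
    `extra_steps` frames recorded after it get 1.
    """
    if success_step < 0 or extra_steps < 0:
        raise ValueError("success_step and extra_steps must be non-negative")
    return [0] * success_step + [1] * (1 + extra_steps)


without_extra = label_rewards(success_step=40, extra_steps=0)
with_extra = label_rewards(success_step=40, extra_steps=15)
```

In this example the episode goes from 1 positive frame out of 41 to 16 positive frames out of 56, which gives the classifier far more positive examples to learn from.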
|
||||
|
||||
Example configuration section for data collection:
|
||||
|
||||
```json
|
||||
{
|
||||
"mode": "record",
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"num_episodes": 20,
|
||||
"push_to_hub": true,
|
||||
"fps": 10,
|
||||
"number_of_steps_after_success": 15
|
||||
}
|
||||
```
|
||||
|
||||
## 2.2 Reward Classifier Configuration
|
||||
|
||||
The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters:
|
||||
|
||||
- **model_name**: Base model architecture (e.g., we mainly use "helper2424/resnet10")
|
||||
- **model_type**: "cnn" or "transformer"
|
||||
- **num_cameras**: Number of camera inputs
|
||||
- **num_classes**: Number of output classes (typically 2 for binary success/failure)
|
||||
- **hidden_dim**: Size of hidden representation
|
||||
- **dropout_rate**: Regularization parameter
|
||||
- **learning_rate**: Learning rate for optimizer
|
||||
|
||||
Example configuration from `reward_classifier_train_config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"policy": {
|
||||
"type": "reward_classifier",
|
||||
"model_name": "helper2424/resnet10",
|
||||
"model_type": "cnn",
|
||||
"num_cameras": 2,
|
||||
"num_classes": 2,
|
||||
"hidden_dim": 256,
|
||||
"dropout_rate": 0.1,
|
||||
"learning_rate": 1e-4,
|
||||
"device": "cuda",
|
||||
"use_amp": true,
|
||||
"input_features": {
|
||||
"observation.images.front": {
|
||||
"type": "VISUAL",
|
||||
"shape": [3, 128, 128]
|
||||
},
|
||||
"observation.images.side": {
|
||||
"type": "VISUAL",
|
||||
"shape": [3, 128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
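Before launching training it can help to sanity-check the configuration. The sketch below assumes that `num_cameras` should equal the number of `VISUAL` entries in `input_features` and that image shapes have three dimensions; both rules are assumptions made for this example rather than documented LeRobot checks, and `check_classifier_config` is a hypothetical helper.

```python
from typing import List


def check_classifier_config(policy: dict) -> List[str]:
    """Return the image feature names, raising if the config looks inconsistent."""
    image_keys = [
        name
        for name, feature in policy["input_features"].items()
        if feature["type"] == "VISUAL"
    ]
    if len(image_keys) != policy["num_cameras"]:
        raise ValueError(
            f"num_cameras={policy['num_cameras']} but {len(image_keys)} image inputs are defined"
        )
    for name in image_keys:
        shape = policy["input_features"][name]["shape"]
        if len(shape) != 3:
            raise ValueError(f"{name}: expected a 3-dimensional image shape, got {shape}")
    if policy["num_classes"] < 2:
        raise ValueError("num_classes must be at least 2")
    return image_keys
```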
|
||||
|
||||
## 2.3 Training the Classifier
|
||||
|
||||
To train the classifier, use the `train.py` script with your configuration:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config_path lerobot/configs/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
## 2.4 Deploying and Testing the Model
|
||||
|
||||
To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to use your model:
|
||||
|
||||
```python
|
||||
env_config = HILSerlRobotEnvConfig(
|
||||
    reward_classifier_pretrained_path="path_to_your_pretrained_model",
|
||||
# Other environment parameters
|
||||
)
|
||||
```
|
||||
Alternatively, set the argument in the JSON config file:
|
||||
|
||||
```json
|
||||
{
|
||||
"reward_classifier_pretrained_path": "path_to_your_pretrained_model"
|
||||
}
|
||||
```
|
||||
|
||||
Run gym_manipulator.py to test the model.
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json
|
||||
```
|
||||
|
||||
The reward classifier will automatically provide rewards based on the visual input from the robot's cameras.
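For the binary case, turning a classifier output into a reward typically means squashing the single output value with a sigmoid and comparing it to a threshold. The sketch below shows that general technique; the function name and the 0.5 default threshold are assumptions, and the actual post-processing in `modeling_classifier.py` may differ.

```python
import math


def reward_from_logit(logit: float, threshold: float = 0.5) -> int:
    """Map a single classifier output (a logit) to a binary reward."""
    # Numerically stable sigmoid: avoids overflow for large negative logits.
    if logit >= 0:
        prob = 1.0 / (1.0 + math.exp(-logit))
    else:
        exp_logit = math.exp(logit)
        prob = exp_logit / (1.0 + exp_logit)
    return 1 if prob >= threshold else 0
```

Raising the threshold makes false positives rarer at the cost of occasionally missing a real success, which is often the safer trade-off when the reward drives policy updates.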
|
||||
|
||||
## 2.5 Example Workflow
|
||||
|
||||
1. **Create the configuration files**:
|
||||
Create the necessary json configuration files for the reward classifier and the environment. Check the `json_examples` directory for examples.
|
||||
|
||||
2. **Collect a dataset**:
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json
|
||||
```
|
||||
|
||||
3. **Train the classifier**:
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config_path lerobot/configs/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
4. **Test the classifier**:
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json
|
||||
```
|
||||
# 3. Using gym_hil Simulation Environments with LeRobot
|
||||
|
||||
This guide explains how to use the `gym_hil` simulation environments as an alternative to real robots when working with the LeRobot framework for Human-In-the-Loop (HIL) reinforcement learning.
|
||||
|
||||
`gym_hil` is a package that provides Gymnasium-compatible simulation environments specifically designed for Human-In-the-Loop reinforcement learning. These environments allow you to:
|
||||
|
||||
- Train policies in simulation to test the RL stack before training on real robots
|
||||
|
||||
- Collect demonstrations in sim using external devices like gamepads or keyboards
|
||||
- Perform human interventions during policy learning
|
||||
|
||||
Currently, the main environment is a Franka Panda robot simulation based on MuJoCo, with tasks like picking up a cube.
|
||||
|
||||
## 3.1 Installation
|
||||
|
||||
First, install the `gym_hil` package within the LeRobot environment:
|
||||
|
||||
```bash
|
||||
pip install gym_hil
|
||||
|
||||
# Or in LeRobot
|
||||
cd lerobot
|
||||
pip install -e ".[hilserl]"
|
||||
```
|
||||
|
||||
## 3.2 Configuration
|
||||
|
||||
To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided in `gym_hil_env.json`. Key configuration sections include:
|
||||
|
||||
### 3.2.1 Environment Type and Task
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "hil",
|
||||
"name": "franka_sim",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Available tasks:
|
||||
- `PandaPickCubeBase-v0`: Basic environment
|
||||
- `PandaPickCubeGamepad-v0`: With gamepad control
|
||||
- `PandaPickCubeKeyboard-v0`: With keyboard control
|
||||
|
||||
### 3.2.2 Gym Wrappers Configuration
|
||||
|
||||
```json
|
||||
"wrapper": {
|
||||
"gripper_penalty": -0.02,
|
||||
"control_time_s": 15.0,
|
||||
"use_gripper": true,
|
||||
"fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785],
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
},
|
||||
"control_mode": "gamepad"
|
||||
}
|
||||
```
|
||||
|
||||
Important parameters:
|
||||
- `gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `use_gripper`: Whether to enable gripper control
|
||||
- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `control_mode`: Set to "gamepad" to use a gamepad controller
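One way to read `end_effector_step_sizes` is as the scale that converts a normalized per-axis command into a metric displacement. The sketch below assumes commands in the range [-1, 1]; that range, and the name `action_to_delta`, are assumptions for illustration rather than the documented behaviour of the wrappers.

```python
from typing import Dict, List, Sequence


def action_to_delta(action: Sequence[float], step_sizes: Dict[str, float]) -> List[float]:
    """Scale a normalized xyz action by the per-axis step sizes."""
    deltas = []
    for value, axis in zip(action, ("x", "y", "z")):
        clipped = max(-1.0, min(1.0, value))  # keep the command inside [-1, 1]
        deltas.append(clipped * step_sizes[axis])
    return deltas


step_sizes = {"x": 0.025, "y": 0.025, "z": 0.025}
delta = action_to_delta([1.0, -0.5, 2.0], step_sizes)
```

Under this reading, a step size of 0.025 caps the end-effector displacement per control step along each axis, so smaller values give finer but slower motion.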
|
||||
|
||||
## 3.3 Running with HIL RL of LeRobot
|
||||
|
||||
### 3.3.1 Basic Usage
|
||||
|
||||
To run the environment, set mode to null:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
|
||||
### 3.3.2 Recording a Dataset
|
||||
|
||||
To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
|
||||
### 3.3.3 Training a Policy
|
||||
|
||||
To train a policy, check out the example JSON in `train_gym_hil_env.json` and run the actor and learner servers:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/actor.py --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
|
||||
In a different terminal, run the learner server:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/learner.py --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
|
||||
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.
|
||||
|
||||
Paper citation:
|
||||
|
||||
```
|
||||
@article{luo2024precise,
|
||||
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
|
||||
author={Luo, Jianlan and Xu, Charles and Wu, Jeffrey and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2410.21845},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
@@ -1,60 +0,0 @@
|
||||
// fmt: off
|
||||
// flake8: noqa
|
||||
// !/usr/bin/env python
|
||||
|
||||
// Copyright 2024 The HuggingFace Inc. team.
|
||||
// All rights reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
syntax = "proto3";
|
||||
|
||||
package async_inference;
|
||||
|
||||
// AsyncInference: from Robot perspective
|
||||
// Robot send observations to & executes action received from a remote Policy server
|
||||
service AsyncInference {
|
||||
// Robot -> Policy to share observations with a remote inference server
|
||||
// Policy -> Robot to share actions predicted for given observations
|
||||
rpc SendObservations(stream Observation) returns (Empty);
|
||||
rpc StreamActions(Empty) returns (stream Action);
|
||||
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
enum TransferState {
|
||||
TRANSFER_UNKNOWN = 0;
|
||||
TRANSFER_BEGIN = 1;
|
||||
TRANSFER_MIDDLE = 2;
|
||||
TRANSFER_END = 3;
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Observation {
|
||||
// sent by Robot, to remote Policy
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Action {
|
||||
// sent by remote Policy, to Robot
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message PolicySetup {
|
||||
// sent by Robot to remote server, to init Policy
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
@@ -1,48 +0,0 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: async_inference.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'async_inference.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x0bPolicySetup\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xa9\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=301
|
||||
_globals['_TRANSFERSTATE']._serialized_end=397
|
||||
_globals['_OBSERVATION']._serialized_start=42
|
||||
_globals['_OBSERVATION']._serialized_end=125
|
||||
_globals['_ACTION']._serialized_start=127
|
||||
_globals['_ACTION']._serialized_end=205
|
||||
_globals['_POLICYSETUP']._serialized_start=207
|
||||
_globals['_POLICYSETUP']._serialized_end=290
|
||||
_globals['_EMPTY']._serialized_start=292
|
||||
_globals['_EMPTY']._serialized_end=299
|
||||
_globals['_ASYNCINFERENCE']._serialized_start=400
|
||||
_globals['_ASYNCINFERENCE']._serialized_end=697
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,236 +0,0 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import async_inference_pb2 as async__inference__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.71.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class AsyncInferenceStub:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendObservations = channel.stream_unary(
|
||||
'/async_inference.AsyncInference/SendObservations',
|
||||
request_serializer=async__inference__pb2.Observation.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.StreamActions = channel.unary_stream(
|
||||
'/async_inference.AsyncInference/StreamActions',
|
||||
request_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Action.FromString,
|
||||
_registered_method=True)
|
||||
self.SendPolicyInstructions = channel.unary_unary(
|
||||
'/async_inference.AsyncInference/SendPolicyInstructions',
|
||||
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/async_inference.AsyncInference/Ready',
|
||||
request_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class AsyncInferenceServicer:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def SendObservations(self, request_iterator, context):
|
||||
"""Robot -> Policy to share observations with a remote inference server
|
||||
Policy -> Robot to share actions predicted for given observations
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def StreamActions(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendPolicyInstructions(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_AsyncInferenceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendObservations': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendObservations,
|
||||
request_deserializer=async__inference__pb2.Observation.FromString,
|
||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'StreamActions': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamActions,
|
||||
request_deserializer=async__inference__pb2.Empty.FromString,
|
||||
response_serializer=async__inference__pb2.Action.SerializeToString,
|
||||
),
|
||||
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendPolicyInstructions,
|
||||
request_deserializer=async__inference__pb2.PolicySetup.FromString,
|
||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=async__inference__pb2.Empty.FromString,
|
||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'async_inference.AsyncInference', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AsyncInference:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot sends observations to & executes actions received from a remote Policy server
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendObservations(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/async_inference.AsyncInference/SendObservations',
|
||||
async__inference__pb2.Observation.SerializeToString,
|
||||
async__inference__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def StreamActions(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/async_inference.AsyncInference/StreamActions',
|
||||
async__inference__pb2.Empty.SerializeToString,
|
||||
async__inference__pb2.Action.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendPolicyInstructions(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/async_inference.AsyncInference/SendPolicyInstructions',
|
||||
async__inference__pb2.PolicySetup.SerializeToString,
|
||||
async__inference__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/async_inference.AsyncInference/Ready',
|
||||
async__inference__pb2.Empty.SerializeToString,
|
||||
async__inference__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Server/Client side: Sometimes you just want the environment to wait a tiny bit"""
|
||||
|
||||
idle_wait = 0.01
|
||||
|
||||
"""Client side: The environment evolves with a time resolution equal to environment_dt"""
|
||||
environment_dt = 1 / 30
|
||||
|
||||
"""Server side: Target inference latency; inference that finishes sooner is padded with sleep up to environment_dt"""
|
||||
inference_latency = environment_dt
|
||||
|
||||
"""Supported policies"""
|
||||
supported_policies = ["act", "smolvla"]
|
||||
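The constants above drive two pieces of timing logic elsewhere in this diff: `_time_action_chunk` stamps action `i` of a chunk at `t_0 + i * environment_dt`, and `_predict_action_chunk` sleeps so that a prediction takes at least `inference_latency`. A minimal standalone sketch of both, using plain floats instead of tensors (the function names `action_timestamps` and `padding_sleep` are hypothetical and not part of the codebase):

```python
environment_dt = 1 / 30            # same value as in the constants module
inference_latency = environment_dt


def action_timestamps(t_0: float, n_actions: int, dt: float = environment_dt) -> list[float]:
    """Timestamp for each action in a chunk: t_0, t_0 + dt, t_0 + 2*dt, ..."""
    return [t_0 + i * dt for i in range(n_actions)]


def padding_sleep(elapsed: float, target: float = inference_latency) -> float:
    """Seconds to sleep so that the whole call takes at least `target` seconds."""
    return max(0.0, target - max(0.0, elapsed))


stamps = action_timestamps(t_0=100.0, n_actions=3)
fast = padding_sleep(elapsed=0.01)   # finished early: pad up to the target
slow = padding_sleep(elapsed=0.05)   # already slower than the target: no sleep
```

Because `inference_latency` equals `environment_dt`, padding of this kind puts a floor of one environment step on each prediction; inference that is already slower than that is not delayed further.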
@@ -1,128 +0,0 @@
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def setup_logging(prefix: str, info_bracket: str):
|
||||
"""Sets up logging"""
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs("logs", exist_ok=True)
|
||||
|
||||
# Delete any existing prefix_* log files
|
||||
for old_log_file in os.listdir("logs"):
|
||||
if old_log_file.startswith(prefix) and old_log_file.endswith(".log"):
|
||||
try:
|
||||
os.remove(os.path.join("logs", old_log_file))
|
||||
print(f"Deleted old log file: {old_log_file}")
|
||||
except Exception as e:
|
||||
print(f"Failed to delete old log file {old_log_file}: {e}")
|
||||
|
||||
# Set up logging with both console and file output
|
||||
logger = logging.getLogger(prefix)
|
||||
# Prevent propagation to root logger to avoid duplicate messages
|
||||
logger.propagate = False
|
||||
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(
|
||||
logging.Formatter(
|
||||
f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# File handler - creates a new log file for each run
|
||||
file_handler = logging.handlers.RotatingFileHandler(
|
||||
f"logs/{prefix}_{int(time.time())}.log",
|
||||
maxBytes=10 * 1024 * 1024, # 10MB
|
||||
backupCount=5,
|
||||
)
|
||||
file_handler.setFormatter(
|
||||
logging.Formatter(
|
||||
f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
class TimedData:
|
||||
def __init__(self, timestamp: float, data: Any, timestep: int):
|
||||
"""Initialize a TimedData object.
|
||||
|
||||
Args:
|
||||
timestamp: Unix timestamp relative to data's creation.
|
||||
data: The actual data to wrap a timestamp around.
|
||||
timestep: The timestep of the data.
|
||||
"""
|
||||
self.timestamp = timestamp
|
||||
self.data = data
|
||||
self.timestep = timestep
|
||||
|
||||
def get_data(self):
|
||||
return self.data
|
||||
|
||||
def get_timestamp(self):
|
||||
return self.timestamp
|
||||
|
||||
def get_timestep(self):
|
||||
return self.timestep
|
||||
|
||||
|
||||
class TimedAction(TimedData):
|
||||
def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
|
||||
super().__init__(timestamp=timestamp, data=action, timestep=timestep)
|
||||
|
||||
def get_action(self):
|
||||
return self.get_data()
|
||||
|
||||
|
||||
class TimedObservation(TimedData):
|
||||
def __init__(
|
||||
self,
|
||||
timestamp: float,
|
||||
observation: dict[str, torch.Tensor],
|
||||
timestep: int,
|
||||
transfer_state: int = 0,
|
||||
must_go: bool = False,
|
||||
):
|
||||
super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
|
||||
self.transfer_state = transfer_state
|
||||
self.must_go = must_go
|
||||
|
||||
def get_observation(self):
|
||||
return self.get_data()
|
||||
|
||||
|
||||
class TinyPolicyConfig:
|
||||
def __init__(
|
||||
self,
|
||||
policy_type: str = "act",
|
||||
pretrained_name_or_path: str = "fracapuano/act_so100_test",
|
||||
device: str = "cpu",
|
||||
):
|
||||
self.policy_type = policy_type
|
||||
self.pretrained_name_or_path = pretrained_name_or_path
|
||||
self.device = device
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
"""Check if two observation states are similar, under a tolerance threshold"""
|
||||
return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
|
||||
|
||||
|
||||
def observations_similar(obs1: TimedObservation, obs2: TimedObservation, atol: float = 1) -> bool:
|
||||
"""Check if two observations are similar, under a tolerance threshold"""
|
||||
obs1_state = obs1.get_observation()["observation.state"]
|
||||
obs2_state = obs2.get_observation()["observation.state"]
|
||||
|
||||
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
|
||||
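`observations_similar` compares the Euclidean distance between two `observation.state` vectors against `atol`. A torch-free sketch of the same check, assuming the states are plain lists of floats (the helper name `states_similar` is hypothetical):

```python
import math


def states_similar(state_a: list[float], state_b: list[float], atol: float = 1.0) -> bool:
    """True when the Euclidean (L2) distance between the two states is below atol."""
    return math.dist(state_a, state_b) < atol


previous = [0.0, 10.0, 20.0]
nearly_same = [0.1, 10.2, 19.9]   # distance is about 0.245, below the default atol
moved = [0.0, 10.0, 25.0]         # distance is 5.0, above the default atol
```

With the default `atol=1`, the server would skip an observation like `nearly_same` and enqueue one like `moved`. The right threshold depends on the units of the robot state, so the default is best treated as a starting point.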
@@ -1,429 +0,0 @@
|
||||
import itertools
|
||||
import pickle # nosec
|
||||
import time
|
||||
from concurrent import futures
|
||||
from queue import Queue
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
import async_inference_pb2 # type: ignore
|
||||
import async_inference_pb2_grpc # type: ignore
|
||||
import grpc
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.common.policies.factory import get_policy_class
|
||||
from lerobot.scripts.server.constants import environment_dt, idle_wait, inference_latency, supported_policies
|
||||
from lerobot.scripts.server.helpers import (
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
TinyPolicyConfig,
|
||||
observations_similar,
|
||||
setup_logging,
|
||||
)
|
||||
|
||||
|
||||
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||
prefix = "policy_server"
|
||||
info_bracket = "SERVER"
|
||||
logger = setup_logging(prefix, info_bracket)
|
||||
|
||||
def __init__(self):
|
||||
# Initialize dataset action generator (to debug this first version, will be removed in the future)
|
||||
self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
|
||||
|
||||
self._setup_server()
|
||||
|
||||
self.actions_per_chunk = 20
|
||||
self.actions_overlap = 10
|
||||
|
||||
self.running = True
|
||||
|
||||
def _setup_server(self) -> None:
|
||||
"""Flushes server state when new client connects."""
|
||||
# only running inference on the latest observation received by the server
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
self._predicted_timesteps = set()
|
||||
self._predicted_observations = Queue(maxsize=1)
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
client_id = context.peer()
|
||||
self.logger.info(f"Client {client_id} connected and ready")
|
||||
self._setup_server()
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def SendPolicyInstructions(self, request, context): # noqa: N802
|
||||
"""Receive policy instructions from the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Receiving policy instructions from {client_id}")
|
||||
|
||||
policy_specs = pickle.loads(request.data) # nosec
|
||||
assert isinstance(policy_specs, TinyPolicyConfig), (
|
||||
f"Policy specs must be a TinyPolicyConfig. Got {type(policy_specs)}"
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Policy type: {policy_specs.policy_type} | "
|
||||
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
|
||||
f"Device: {policy_specs.device}"
|
||||
)
|
||||
|
||||
assert policy_specs.policy_type in supported_policies, (
|
||||
f"Policy type {policy_specs.policy_type} not supported. Supported policies: {supported_policies}"
|
||||
)
|
||||
|
||||
self.device = policy_specs.device
|
||||
self.policy_type = policy_specs.policy_type # one of supported_policies, e.g. "act" or "smolvla"
|
||||
|
||||
policy_class = get_policy_class(self.policy_type)
|
||||
|
||||
start = time.time()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
end = time.time()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def SendObservations(self, request_iterator, context): # noqa: N802
|
||||
"""Receive observations from the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Receiving observations from {client_id}")
|
||||
|
||||
for observation in request_iterator:
|
||||
receive_time = time.time()
|
||||
timed_observation = pickle.loads(observation.data) # nosec
|
||||
deserialize_time = time.time()
|
||||
|
||||
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
|
||||
|
||||
if not self._maybe_enqueue_observation(timed_observation):
|
||||
continue
|
||||
|
||||
queue_time = time.time()
|
||||
|
||||
obs_timestep = timed_observation.get_timestep()
|
||||
obs_timestamp = timed_observation.get_timestamp()
|
||||
|
||||
self.logger.info(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Client timestamp: {obs_timestamp:.6f} | "
|
||||
f"Server timestamp: {receive_time:.6f} | "
|
||||
)
|
||||
|
||||
if not hasattr(self, "previous_obs_timestamp"):
|
||||
self.previous_obs_timestamp = obs_timestamp
|
||||
|
||||
self.logger.debug(
|
||||
f"1/DeltaObsT (~frequency): {1 / (1e-6 + obs_timestamp - self.previous_obs_timestamp):.6f} Hz | "
|
||||
f"Network latency: {receive_time - obs_timestamp:.6f}s | "
|
||||
f"Deserialization time: {deserialize_time - receive_time:.6f}s | "
|
||||
f"Queue time: {queue_time - deserialize_time:.6f}s | "
|
||||
)
|
||||
|
||||
self.previous_obs_timestamp = obs_timestamp
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def StreamActions(self, request, context): # noqa: N802
|
||||
"""Stream actions to the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Client {client_id} connected for action streaming")
|
||||
|
||||
# Generate action based on the most recent observation and its timestep
|
||||
try:
|
||||
obs = self.observation_queue.get()
|
||||
self.logger.info(
|
||||
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
|
||||
)
|
||||
|
||||
if obs:
|
||||
self.last_predicted_obs = obs
|
||||
self._predicted_timesteps.add(obs.get_timestep())
|
||||
start_time = time.time()
|
||||
action_chunk = self._predict_action_chunk(obs)
|
||||
# action_chunk = self._read_action_chunk(obs)
|
||||
inference_time = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
action_bytes = pickle.dumps(action_chunk) # nosec
|
||||
serialize_time = time.time() - start_time
|
||||
|
||||
# Create and return the Action
|
||||
action = async_inference_pb2.Action(transfer_state=obs.transfer_state, data=action_bytes)
|
||||
|
||||
self.logger.info(
|
||||
f"Action chunk #{obs.get_timestep()} generated | Inference time: {inference_time:.6f}s |"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Inference time: {inference_time:.6f}s |"
|
||||
f"Serialize time: {serialize_time:.6f}s |"
|
||||
f"Total time: {inference_time + serialize_time:.6f}s"
|
||||
)
|
||||
|
||||
yield action
|
||||
else:
|
||||
self.logger.warning("No observation in queue yet!")
|
||||
time.sleep(idle_wait)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in StreamActions: {e}")
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def _enqueue_and_go(self, obs: TimedObservation):
|
||||
# If queue is full, get the old observation to make room
|
||||
if self.observation_queue.full():
|
||||
# pops from queue
|
||||
_ = self.observation_queue.get_nowait()
|
||||
self.logger.debug("Observation queue was full, removed oldest observation")
|
||||
|
||||
# Now put the new observation (never blocks as queue is non-full here)
|
||||
self.observation_queue.put(obs)
|
||||
return True
|
||||
|
||||
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
|
||||
if obs.get_timestep() in self._predicted_timesteps:
|
||||
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
|
||||
return False
|
||||
|
||||
elif observations_similar(obs, previous_obs, atol=1):
|
||||
self.logger.debug(
|
||||
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
|
||||
)
|
||||
return False
|
||||
|
||||
else:
|
||||
return True
|
||||
|
||||
def _maybe_enqueue_observation(self, obs: TimedObservation) -> bool:
|
||||
"""Enqueue an observation if it must go through processing, otherwise skip it.
|
||||
Observations not in queue are never run through the policy network"""
|
||||
|
||||
if obs.must_go or not hasattr(self, "last_predicted_obs"):
|
||||
self.logger.info(f"[MUST GO] Enqueued observation #{obs.get_timestep()} for direct processing!")
|
||||
return self._enqueue_and_go(obs)
|
||||
|
||||
else:
|
||||
if self._obs_sanity_checks(obs, self.last_predicted_obs):
|
||||
return self._enqueue_and_go(obs)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
|
||||
"""Turn a chunk of actions into a list of TimedAction instances,
|
||||
with the first action corresponding to t_0 and the rest corresponding to
|
||||
t_0 + i*environment_dt for i in range(len(action_chunk))
|
||||
"""
|
||||
return [
|
||||
TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def _run_act_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Run ACT-like policies"""
|
||||
start_time = time.time()
|
||||
|
||||
# prepare observation for policy forward pass
|
||||
batch = self.policy.normalize_inputs(observation)
|
||||
normalize_time = time.time()
|
||||
self.logger.debug(f"Observation normalization time: {normalize_time - start_time:.6f}s")
|
||||
|
||||
if self.policy.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.policy.config.image_features]
|
||||
prep_time = time.time()
|
||||
self.logger.debug(f"Observation image preparation time: {prep_time - normalize_time:.6f}s")
|
||||
|
||||
# the forward pass outputs up to policy.config.n_action_steps actions, which may differ from actions_per_chunk, so slice to actions_per_chunk
|
||||
actions = self.policy.model(batch)[0][:, : self.actions_per_chunk]
|
||||
|
||||
actions = self.policy.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
end_time = time.time()
|
||||
self.logger.info(f"[ACT] Action chunk generation total time: {end_time - start_time:.6f}s")
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def _run_pi0_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Run PI0-like policies"""
|
||||
raise NotImplementedError("PI0 policy not implemented yet")
|
||||
|
||||
@torch.no_grad()
|
||||
def _run_smolvla_policy(
|
||||
self, observation: dict[str, torch.Tensor], noise: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""Run smolvla-like policies"""
|
||||
observation = self.policy.normalize_inputs(observation)
|
||||
|
||||
images, img_masks = self.policy.prepare_images(observation)
|
||||
state = self.policy.prepare_state(observation)
|
||||
lang_tokens, lang_masks = self.policy.prepare_language(observation)
|
||||
|
||||
actions = self.policy.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.policy.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.policy.unnormalize_outputs(
|
||||
{"action": actions, "robot_type": [self.policy.config.robot_type]}
|
||||
)["action"]
|
||||
|
||||
return actions
|
||||
|
||||
def _get_action_chunk(
|
||||
self, observation: dict[str, torch.Tensor], policy_type: str = "act"
|
||||
) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy"""
|
||||
if policy_type == "act":
|
||||
return self._run_act_policy(observation)
|
||||
elif policy_type == "smolvla":
|
||||
return self._run_smolvla_policy(observation)
|
||||
else:
|
||||
raise ValueError(f"Policy class {policy_type} not supported")
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action based on the observation"""
|
||||
"""1. Prepare observation"""
|
||||
start_time = time.time()
|
||||
|
||||
observation = {
|
||||
"robot_type": [self.policy.config.robot_type],
|
||||
}
|
||||
for k, v in observation_t.get_observation().items():
|
||||
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions
|
||||
if "image" in k:
|
||||
# Add batch dimension first, then reorder to NCHW format, then normalize to [0, 1]
|
||||
observation[k] = (
|
||||
v.unsqueeze(0).permute(0, 3, 1, 2).to(self.device, non_blocking=True) / 255.0
|
||||
)
|
||||
else:
|
||||
observation[k] = v.unsqueeze(0).to(self.device, non_blocking=True)
|
||||
else:
|
||||
observation[k] = v # textual instructions are passed as a list of strings
|
||||
|
||||
prep_time = time.time()
|
||||
self.logger.debug(f"Observation preparation time: {prep_time - start_time:.6f}s")
|
||||
|
||||
"""2. Get action chunk"""
|
||||
action_tensor = self._get_action_chunk(observation, self.policy_type)
|
||||
action_tensor = action_tensor.squeeze(0)
|
||||
|
||||
# Move to CPU before serializing
|
||||
action_tensor = action_tensor.cpu()
|
||||
|
||||
post_inference_time = time.time()
|
||||
self.logger.debug(f"Post-inference processing start: {post_inference_time - prep_time:.6f}s")
|
||||
|
||||
if action_tensor.dim() == 1:
|
||||
# No chunk dimension, so repeat action to create a (dummy) chunk of actions
|
||||
action_tensor = action_tensor.repeat(self.actions_per_chunk, 1)
|
||||
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
|
||||
chunk_time = time.time()
|
||||
self.logger.debug(f"Action chunk creation time: {chunk_time - post_inference_time:.6f}s")
|
||||
time.sleep(
|
||||
max(0, inference_latency - max(0, chunk_time - start_time))
|
||||
) # sleep to control inference latency
|
||||
|
||||
return action_chunk
|
||||
|
||||
def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
|
||||
"""Stream chunks of actions from a prerecorded dataset.
|
||||
|
||||
Returns:
|
||||
Generator that yields chunks of actions from the dataset
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")
|
||||
|
||||
# 1. Select the action column only, where you will find tensors with 6 elements
|
||||
actions = dataset["action"]
|
||||
action_indices = torch.arange(len(actions))
|
||||
|
||||
# 2. Chunk the action indices into windows of actions_per_chunk elements,
|
||||
# with consecutive windows overlapping by actions_overlap elements
|
||||
indices_chunks = action_indices.unfold(
|
||||
0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
|
||||
)
|
||||
|
||||
for idx_chunk in indices_chunks:
|
||||
yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
|
||||
|
||||
def _read_action_chunk(self, observation: Optional[TimedObservation] = None) -> list[TimedAction]:
|
||||
"""Dummy function for predicting action chunk given observation.
|
||||
|
||||
Instead of computing actions on-the-fly, this method streams
|
||||
actions from a prerecorded dataset.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
if not observation:
|
||||
observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
|
||||
|
||||
# Get chunk of actions from the generator
|
||||
actions_chunk = next(self.action_generator)
|
||||
|
||||
# Return a list of TimedActions, with timestamps starting from the observation timestamp
|
||||
actions_chunk = self._time_action_chunk(
|
||||
observation.get_timestamp(), actions_chunk, observation.get_timestep()
|
||||
)
|
||||
|
||||
chunk_time = time.time()
|
||||
self.logger.debug(f"Action chunk creation time: {chunk_time - start_time:.6f}s")
|
||||
|
||||
# slow action generation, emulates inference time
|
||||
time.sleep(max(0, inference_latency - max(0, chunk_time - start_time)))
|
||||
|
||||
return actions_chunk
|
||||
|
||||
def stop(self):
|
||||
"""Stop the server"""
|
||||
self.running = False
|
||||
self.logger.info("Server stopping...")
|
||||
|
||||
|
||||
def serve():
|
||||
port = 8080
|
||||
# Create the server instance first
|
||||
policy_server = PolicyServer()
|
||||
|
||||
# Setup and start gRPC server
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||
server.add_insecure_port(f"[::]:{port}")
|
||||
server.start()
|
||||
policy_server.logger.info(f"PolicyServer started on port {port}")
|
||||
|
||||
try:
|
||||
# Use the running attribute to control server lifetime
|
||||
while policy_server.running:
|
||||
time.sleep(1) # Check every second instead of sleeping indefinitely
|
||||
|
||||
except KeyboardInterrupt:
|
||||
policy_server.stop()
|
||||
policy_server.logger.info("Keyboard interrupt received")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
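`PolicyServer` keeps only the most recent observation by pairing `Queue(maxsize=1)` with the drop-then-put step in `_enqueue_and_go`. The pattern in isolation, using only the standard library (the function name `put_latest` is hypothetical, and this sketch assumes a single producer thread):

```python
from queue import Empty, Queue


def put_latest(q: Queue, item) -> None:
    """Replace whatever is queued with `item`, so a consumer only sees the newest one."""
    if q.full():
        try:
            q.get_nowait()   # drop the stale item
        except Empty:
            pass             # a consumer took it first; nothing to drop
    q.put(item)


observations = Queue(maxsize=1)
for timestep in range(5):
    put_latest(observations, timestep)

latest = observations.get_nowait()
```

Catching `Empty` covers the case where a consumer drains the queue between the `full()` check and `get_nowait()`. With several producer threads the final `put` could still block, so a lock around the whole function would be a safer choice in that setting.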
@@ -1,608 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from queue import Empty, Queue
|
||||
from typing import Callable, Optional
|
||||
|
||||
import async_inference_pb2 # type: ignore
|
||||
import async_inference_pb2_grpc # type: ignore
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from lerobot.scripts.server.constants import environment_dt, idle_wait
|
||||
from lerobot.scripts.server.helpers import TimedAction, TimedObservation, TinyPolicyConfig, setup_logging
|
||||
|
||||
|
||||
class RobotClient:
|
||||
prefix = "robot_client"
|
||||
info_bracket = "CLIENT"
|
||||
logger = setup_logging(prefix, info_bracket)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_address: Optional[str] = None,
|
||||
policy_type: str = "smolvla",
|
||||
pretrained_name_or_path: str = "lerobot/smolvla_base",
|
||||
policy_device: str = "cuda",
|
||||
chunk_size_threshold: float = 0.5,
|
||||
robot: str = "so100",
|
||||
):
|
||||
# Use environment variable if server_address is not provided
|
||||
if server_address is None:
|
||||
server_address = os.getenv("SERVER_ADDRESS", "localhost:8080")
|
||||
self.logger.info(f"No server address provided, using default address: {server_address}")
|
||||
|
||||
self.policy_config = TinyPolicyConfig(policy_type, pretrained_name_or_path, policy_device)
|
||||
self.channel = grpc.insecure_channel(server_address)
|
||||
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||
self.logger.info(f"Initializing client to connect to server at {server_address}")
|
||||
|
||||
self.running = False
|
||||
self.must_go = True # does the observation qualify for direct processing on the policy server?
|
||||
|
||||
self.latest_action = -1
|
||||
self.action_chunk_size = -1
|
||||
|
||||
self._chunk_size_threshold = chunk_size_threshold
|
||||
|
||||
self.action_queue = Queue()
|
||||
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
|
||||
|
||||
start_time = time.time()
|
||||
self.robot = make_robot(robot)
|
||||
self.robot.connect()
|
||||
|
||||
connect_time = time.time()
|
||||
self.logger.info(f"Robot connection time: {connect_time - start_time:.4f}s")
|
||||
|
||||
time.sleep(idle_wait) # sleep waiting for cameras to activate
|
||||
self.logger.info("Robot connected and ready")
|
||||
|
||||
def timestamps(self):
|
||||
"""Get the timesteps of the actions in the queue"""
|
||||
return sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
|
||||
def start(self):
|
||||
"""Start the robot client and connect to the policy server"""
|
||||
try:
|
||||
# client-server handshake
|
||||
start_time = time.time()
|
||||
self.stub.Ready(async_inference_pb2.Empty())
|
||||
end_time = time.time()
|
||||
self.logger.info(f"Connected to policy server in {end_time - start_time:.4f}s")
|
||||
|
||||
# send policy instructions
|
||||
policy_config_bytes = pickle.dumps(self.policy_config)
|
||||
policy_setup = async_inference_pb2.PolicySetup(
|
||||
transfer_state=async_inference_pb2.TRANSFER_BEGIN, data=policy_config_bytes
|
||||
)
|
||||
|
||||
self.logger.info("Sending policy instructions to policy server")
|
||||
self.logger.info(
|
||||
f"Policy type: {self.policy_config.policy_type} | "
|
||||
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
|
||||
f"Device: {self.policy_config.device}"
|
||||
)
|
||||
|
||||
self.stub.SendPolicyInstructions(policy_setup)
|
||||
|
||||
self.running = True
|
||||
self.available_actions_size = []
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the robot client"""
|
||||
self.running = False
|
||||
|
||||
self.robot.disconnect()
|
||||
self.logger.info("Robot disconnected")
|
||||
|
||||
self.channel.close()
|
||||
self.logger.info("Client stopped, channel closed")
|
||||
|
||||
def send_observation(
|
||||
self,
|
||||
obs: TimedObservation,
|
||||
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
|
||||
) -> bool:
|
||||
"""Send observation to the policy server.
|
||||
Returns True if the observation was sent successfully, False otherwise."""
|
||||
if not self.running:
|
||||
self.logger.warning("Client not running")
|
||||
return False
|
||||
|
||||
assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
|
||||
|
||||
start_time = time.time()
|
||||
observation_bytes = pickle.dumps(obs)
|
||||
serialize_time = time.time()
|
||||
self.logger.debug(f"Observation serialization time: {serialize_time - start_time:.6f}s")
|
||||
|
||||
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
|
||||
|
||||
try:
|
||||
send_start = time.time()
|
||||
_ = self.stub.SendObservations(iter([observation]))
|
||||
send_end = time.time()
|
||||
|
||||
obs_timestep = obs.get_timestep()
|
||||
|
||||
self.logger.info(
|
||||
f"Sent observation #{obs_timestep} | "
|
||||
f"Serialize time: {serialize_time - start_time:.6f}s | "
|
||||
f"Network time: {send_end - send_start:.6f}s | "
|
||||
f"Total time: {send_end - start_time:.6f}s"
|
||||
)
|
||||
|
||||
self.last_obs_sent_time = send_end
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
|
||||
return False
|
||||
|
||||
def _validate_action(self, action: TimedAction):
"""Received actions are kept only if their timestep is later than the latest executed action, never earlier"""
return action.get_timestep() > self.latest_action
|
||||
|
||||
def _inspect_action_queue(self):
|
||||
queue_size = self.action_queue.qsize()
|
||||
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
|
||||
return queue_size, timestamps
|
||||
|
||||
def _update_action_queue(self, actions: list[TimedAction]):
|
||||
"""Update the action queue with new actions, without ever emptying the queue"""
|
||||
|
||||
new_queue = Queue()
|
||||
for action in actions:
|
||||
if self._validate_action(action):
|
||||
new_queue.put(action)
|
||||
|
||||
self.action_queue = new_queue
|
||||
|
||||
def _aggregate_action_queues(
|
||||
self,
|
||||
incoming_actions: list[TimedAction],
|
||||
aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
):
|
||||
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
|
||||
# TODO(fracapuano): move outside of the function and make aggregate_fn an always required argument
|
||||
if not aggregate_fn:
|
||||
# default aggregate function: take the latest action
|
||||
def aggregate_fn(x1, x2):
|
||||
return x2
|
||||
|
||||
action_intersections: list[torch.Tensor] = []
|
||||
current_action_queue = {
|
||||
action.get_timestep(): action.get_action() for action in self.action_queue.queue
|
||||
}
|
||||
|
||||
for new_action in incoming_actions:
|
||||
if new_action.get_timestep() in current_action_queue:
|
||||
# TODO(fracapuano): There is probably a way to do this with broadcasting of the two action tensors
|
||||
action_intersections.append(
|
||||
TimedAction(
|
||||
timestamp=new_action.get_timestamp(),
|
||||
action=aggregate_fn(
|
||||
current_action_queue[new_action.get_timestep()], new_action.get_action()
|
||||
),
|
||||
timestep=new_action.get_timestep(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
action_intersections.append(new_action)
|
||||
|
||||
new_queue = Queue()
|
||||
for action in action_intersections:
|
||||
if self._validate_action(action):
|
||||
new_queue.put(action)
|
||||
|
||||
self.action_queue = new_queue
|
||||
|
||||
def _clear_action_queue(self):
|
||||
"""Clear the existing queue"""
|
||||
while not self.action_queue.empty():
|
||||
try:
|
||||
self.action_queue.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
def _fill_action_queue(self, actions: list[TimedAction]):
|
||||
"""Fill the action queue with incoming valid actions"""
|
||||
start_time = time.time()
|
||||
valid_count = 0
|
||||
|
||||
for action in actions:
|
||||
if self._validate_action(action):
|
||||
self.action_queue.put(action)
|
||||
valid_count += 1
|
||||
|
||||
end_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Queue filled: {valid_count}/{len(actions)} valid actions added in {end_time - start_time:.6f}s"
|
||||
)
|
||||
|
||||
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
|
||||
self._clear_action_queue()
|
||||
self._fill_action_queue(actions)
|
||||
|
||||
def receive_actions(self):
|
||||
"""Receive actions from the policy server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Action receiving thread starting")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Use StreamActions to get a stream of actions from the server
|
||||
for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
|
||||
receive_time = time.time()
|
||||
|
||||
# Deserialize bytes back into list[TimedAction]
|
||||
deserialize_start = time.time()
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
deserialize_end = time.time()
|
||||
|
||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
self.logger.info(f"Current latest action: {self.latest_action}")
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
old_timesteps = [self.latest_action] # queue was empty
|
||||
|
||||
# Log incoming actions
|
||||
incoming_timesteps = [a.get_timestep() for a in timed_actions]
|
||||
|
||||
# Calculate network latency if we have matching observations
|
||||
if len(timed_actions) > 0:
|
||||
first_action_timestep = timed_actions[0].get_timestep()
|
||||
server_to_client_latency = receive_time - self.last_obs_sent_time
|
||||
|
||||
self.logger.info(
|
||||
f"Received action chunk for step #{first_action_timestep} | "
|
||||
f"Latest action: #{self.latest_action} | "
|
||||
f"Network latency (server->client): {server_to_client_latency:.6f}s | "
|
||||
f"Deserialization time: {deserialize_end - deserialize_start:.6f}s"
|
||||
)
|
||||
|
||||
# Update action queue
|
||||
start_time = time.time()
|
||||
self._update_action_queue(timed_actions)
|
||||
queue_update_time = time.time() - start_time
|
||||
|
||||
self.must_go = (
|
||||
True # after receiving actions, next empty queue triggers must-go processing!
|
||||
)
|
||||
|
||||
# Get queue state after changes
|
||||
new_size, new_timesteps = self._inspect_action_queue()
|
||||
|
||||
self.logger.info(
f"Queue update complete ({queue_update_time:.6f}s) | "
f"Before: {old_size} items | "
f"After: {new_size} items"
)
|
||||
self.logger.info(
|
||||
f"Latest action: {self.latest_action} | "
|
||||
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
|
||||
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error receiving actions: {e}")
|
||||
# Avoid tight loop on action receiver error
|
||||
time.sleep(idle_wait)
|
||||
|
||||
def _actions_available(self):
|
||||
"""Check if there are actions available in the queue"""
|
||||
return not self.action_queue.empty()
|
||||
|
||||
def _get_next_action(self) -> Optional[TimedAction]:
|
||||
"""Get the next action from the queue"""
|
||||
try:
|
||||
action = self.action_queue.get_nowait()
|
||||
return action
|
||||
|
||||
except Empty:
|
||||
return None
|
||||
|
||||
def _perform_action(self, timed_action: TimedAction):
|
||||
self.robot.send_action(timed_action.get_action())
|
||||
self.latest_action = timed_action.get_timestep()
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={timed_action.get_timestamp()} | "
|
||||
f"Action #{timed_action.get_timestep()} performed | "
|
||||
f"Queue size: {self.action_queue.qsize()}"
|
||||
)
|
||||
|
||||
def execute_actions(self):
|
||||
"""Continuously execute actions from the queue"""
|
||||
import warnings
|
||||
|
||||
warnings.warn("This method is deprecated! Will be removed soon!", stacklevel=2)
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
time.sleep(idle_wait) # wait for observation capture to start
|
||||
|
||||
self.logger.info("Action execution thread starting")
|
||||
|
||||
while self.running:
|
||||
# constantly monitor the size of the action queue
|
||||
self.available_actions_size.append(self.action_queue.qsize())
|
||||
|
||||
if self._actions_available():
|
||||
timed_action = self._get_next_action()
|
||||
self._perform_action(timed_action)
|
||||
|
||||
time.sleep(environment_dt)
|
||||
|
||||
else:
|
||||
self.logger.debug("No action available | Sleeping")
|
||||
time.sleep(idle_wait)
|
||||
|
||||
def stream_observations(self, get_observation_fn):
|
||||
"""Continuously stream observations to the server"""
|
||||
import warnings
|
||||
|
||||
warnings.warn("This method is deprecated! Will be removed soon!", stacklevel=2)
|
||||
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Observation streaming thread starting")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Capture a TimedObservation via the provided function (serialization happens in send_observation)
|
||||
start_time = time.time()
|
||||
observation = get_observation_fn()
|
||||
obs_capture_time = time.time() - start_time
|
||||
|
||||
self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")
|
||||
|
||||
if not hasattr(self, "last_obs_timestamp"):
|
||||
self.last_obs_timestamp = observation.get_timestamp()
|
||||
|
||||
obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()
|
||||
self.logger.info(
|
||||
f"Ts={obs_timestamp} | "
|
||||
f"Captured observation #{obs_timestep} | "
|
||||
f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
|
||||
)
|
||||
|
||||
self.last_obs_timestamp = obs_timestamp
|
||||
|
||||
# Set appropriate transfer state
|
||||
if obs_timestep == 0:
|
||||
state = async_inference_pb2.TRANSFER_BEGIN
|
||||
else:
|
||||
state = async_inference_pb2.TRANSFER_MIDDLE
|
||||
|
||||
time.sleep(environment_dt)
|
||||
self.send_observation(observation, state)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in observation sender: {e}")
|
||||
time.sleep(idle_wait)
|
||||
|
||||
def control_loop_action(self):
|
||||
"""Reading and performing actions in local queue"""
|
||||
self.available_actions_size.append(self.action_queue.qsize())
|
||||
if self._actions_available():
|
||||
# Get action from queue
|
||||
get_start = time.time()
|
||||
timed_action = self._get_next_action()
|
||||
get_end = time.time() - get_start
|
||||
|
||||
self.logger.debug(
|
||||
f"Popping action from queue to perform took {get_end:.6f}s | "
|
||||
f"Queue size: {self.action_queue.qsize()}"
|
||||
)
|
||||
|
||||
self._perform_action(timed_action)
|
||||
|
||||
def _ready_to_send_observation(self):
|
||||
"""Flags when the client is ready to send an observation"""
|
||||
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
|
||||
|
||||
def control_loop_observation(self, get_observation_fn):
|
||||
try:
|
||||
# Capture a TimedObservation via the provided function (serialization happens in send_observation)
|
||||
start_time = time.time()
|
||||
observation = get_observation_fn()
|
||||
obs_capture_time = time.time() - start_time
|
||||
|
||||
# If there are no actions left in the queue, the observation must go through processing!
|
||||
observation.must_go = self.must_go and self.action_queue.empty()
|
||||
self.logger.debug(f"QUEUE SIZE: {self.action_queue.qsize()} (Must go: {observation.must_go})")
|
||||
if observation.must_go:
|
||||
# must-go flag will be set again after receiving actions
|
||||
self.must_go = False
|
||||
|
||||
if not hasattr(self, "last_obs_timestamp"):
|
||||
self.last_obs_timestamp = observation.get_timestamp()
|
||||
|
||||
obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()
|
||||
self.logger.info(
f"Ts={obs_timestamp} | "
f"Captured observation #{obs_timestep} | "
f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
)

# Update only after logging, so the frequency estimate uses the previous observation's timestamp
self.last_obs_timestamp = obs_timestamp
|
||||
|
||||
self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")
|
||||
|
||||
# Set appropriate transfer state
|
||||
if obs_timestep == 0:
|
||||
state = async_inference_pb2.TRANSFER_BEGIN
|
||||
else:
|
||||
state = async_inference_pb2.TRANSFER_MIDDLE
|
||||
|
||||
self.send_observation(observation, state)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in observation sender: {e}")
|
||||
|
||||
def control_loop(self, get_observation_fn):
|
||||
"""Combined function for executing actions and streaming observations"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Control loop thread starting")
|
||||
|
||||
control_loops = 0
|
||||
while self.running:
|
||||
control_loop_start = time.time()
|
||||
# Control loop: (1) Performing actions from the local queue
self.control_loop_action()

# Control loop: (2) Streaming observations to the remote policy server
|
||||
if self._ready_to_send_observation() or control_loops == 0:
|
||||
self.control_loop_observation(get_observation_fn)
|
||||
|
||||
# Dynamically adjust sleep time to maintain the desired control frequency
|
||||
time.sleep(max(0, environment_dt - (time.time() - control_loop_start)))
|
||||
control_loops += 1
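
The gating rule in `_ready_to_send_observation` can be illustrated on its own. The sketch below is a standalone approximation, not the client's code: the function name `ready_to_send_observation` and the guard for a zero chunk size are assumptions added for illustration.

```python
def ready_to_send_observation(queue_size: int, chunk_size: int, threshold: float) -> bool:
    """Return True when the queued fraction of a chunk is at or below the threshold.

    Hypothetical helper mirroring `qsize() / action_chunk_size <= threshold`.
    """
    if chunk_size <= 0:
        # Assumption: before any chunk has arrived there is nothing to divide by,
        # so the sketch treats the client as ready to send.
        return True
    return queue_size / chunk_size <= threshold


# With a threshold of 0.5 and chunks of 10 actions, a new observation is
# sent once 5 or fewer actions remain in the queue.
half_left = ready_to_send_observation(queue_size=5, chunk_size=10, threshold=0.5)
mostly_full = ready_to_send_observation(queue_size=6, chunk_size=10, threshold=0.5)
no_chunk_yet = ready_to_send_observation(queue_size=0, chunk_size=0, threshold=0.5)
```

A lower threshold delays the next request until the queue is nearly drained, while a threshold of 1.0 would request a new chunk on every control step.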
|
||||
|
||||
|
||||
def async_client(task_instruction: str, verbose: int = 0):
|
||||
client = RobotClient()
|
||||
|
||||
if client.start():
|
||||
# Function to get observations from the robot
|
||||
def get_observation():
|
||||
observation_content = client.robot.capture_observation()
|
||||
|
||||
observation_content["task"] = [task_instruction]
|
||||
|
||||
observation = TimedObservation(
|
||||
timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
|
||||
)
|
||||
|
||||
return observation
|
||||
|
||||
client.logger.info("Starting all threads...")
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions)
|
||||
action_receiver_thread.daemon = True
|
||||
|
||||
control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
|
||||
control_loop_thread.daemon = True
|
||||
|
||||
# Start all threads
|
||||
action_receiver_thread.start()
|
||||
control_loop_thread.start()
|
||||
|
||||
try:
|
||||
while client.running:
|
||||
time.sleep(idle_wait)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
client.logger.info("Client stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Robot client for executing tasks via policy server")
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Task instruction for the robot to execute (e.g., 'fold my tshirt')",
|
||||
)
|
||||
parser.add_argument("--verbose", type=int, default=0, help="Verbosity level (default: 0)")
|
||||
parser.add_argument(
|
||||
"--server-port-address",
|
||||
type=str,
|
||||
default="localhost:8080",
|
||||
help="Server & port address (default: localhost:8080, or SERVER_ADDRESS env var)",
|
||||
)
|
||||
parser.add_argument("--policy-type", type=str, default="smolvla", help="Policy type (default: smolvla)")
|
||||
parser.add_argument(
|
||||
"--pretrained-name-or-path",
|
||||
type=str,
|
||||
default="lerobot/smolvla_base",
|
||||
help="Pretrained model name or path (default: lerobot/smolvla_base)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy-device", type=str, default="cuda", help="Device for policy inference (default: cuda)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-size-threshold",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Chunk size threshold (`g` in the paper, default: 0.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot",
|
||||
type=str,
|
||||
default="so100",
|
||||
help="Robot name, as per the `make_robot` function (default: so100)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create client with parsed arguments
|
||||
client = RobotClient(
|
||||
server_address=args.server_port_address,
|
||||
policy_type=args.policy_type,
|
||||
pretrained_name_or_path=args.pretrained_name_or_path,
|
||||
policy_device=args.policy_device,
|
||||
chunk_size_threshold=args.chunk_size_threshold,
|
||||
robot=args.robot,
|
||||
)
|
||||
|
||||
if client.start():
|
||||
# Function to get observations from the robot
|
||||
def get_observation():
|
||||
observation_content = client.robot.capture_observation()
|
||||
|
||||
observation_content["task"] = [args.task]
|
||||
|
||||
observation = TimedObservation(
|
||||
timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
|
||||
)
|
||||
|
||||
return observation
|
||||
|
||||
client.logger.info("Starting all threads...")
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions)
|
||||
action_receiver_thread.daemon = True
|
||||
|
||||
control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
|
||||
control_loop_thread.daemon = True
|
||||
|
||||
# Start all threads
|
||||
action_receiver_thread.start()
|
||||
control_loop_thread.start()
|
||||
|
||||
try:
|
||||
while client.running:
|
||||
time.sleep(idle_wait)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
client.logger.info("Client stopped")
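
The merging logic of `_aggregate_action_queues` above can be sketched without gRPC or torch. This is an illustrative approximation under stated assumptions: `SketchAction` and `merge_chunks` are hypothetical names, actions carry a plain float instead of a tensor, and, as in the method above, queued actions whose timestep is absent from the incoming chunk are dropped.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class SketchAction:
    """Hypothetical stand-in for TimedAction: a timestep plus a scalar action."""

    timestep: int
    value: float


def merge_chunks(
    queued: list[SketchAction],
    incoming: list[SketchAction],
    latest_executed: int,
    aggregate_fn: Callable[[float, float], float] = lambda old, new: new,
) -> list[SketchAction]:
    """Merge an incoming chunk with the queue, aggregating overlapping timesteps."""
    queued_by_step = {a.timestep: a.value for a in queued}
    merged = []
    for action in incoming:
        if action.timestep <= latest_executed:
            # Stale action: its timestep has already been executed.
            continue
        if action.timestep in queued_by_step:
            # Overlap: combine the queued and incoming values.
            value = aggregate_fn(queued_by_step[action.timestep], action.value)
            merged.append(SketchAction(action.timestep, value))
        else:
            merged.append(action)
    return merged


queued = [SketchAction(3, 1.0), SketchAction(4, 2.0)]
incoming = [SketchAction(2, 9.0), SketchAction(4, 4.0), SketchAction(5, 5.0)]
merged = merge_chunks(queued, incoming, latest_executed=2, aggregate_fn=lambda old, new: (old + new) / 2)
```

The default `aggregate_fn` keeps the newer value, matching the fallback in the method above; averaging, as in the usage lines, is one possible alternative for smoothing transitions between chunks.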
|
||||
@@ -63,7 +63,7 @@ dependencies = [
|
||||
"opencv-python-headless>=4.9.0",
|
||||
"packaging>=24.2",
|
||||
"av>=14.2.0",
|
||||
"pymunk>=6.6.0",
|
||||
"pymunk>=6.6.0,<7.0.0",
|
||||
"pynput>=1.7.7",
|
||||
"pyzmq>=26.2.1",
|
||||
"rerun-sdk>=0.21.0",
|
||||
|
||||