Compare commits
3 commits: user/fraca...pre-commit

| Author | SHA1 | Date |
|---|---|---|
| | 8d20ee7655 | |
| | 1537d0ab90 | |
| | 2be7f3a3ff | |
.github/workflows/build-docker-images.yml (vendored, 24 changes)

```diff
@@ -40,24 +40,24 @@ jobs:
           git lfs install

       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
         with:
           cache-binary: false

       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           lfs: true
           persist-credentials: false

       - name: Login to DockerHub
-        uses: docker/login-action@v3
+        uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_PASSWORD }}

       - name: Build and Push CPU
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
         with:
           context: .
           file: ./docker/lerobot-cpu/Dockerfile
@@ -78,24 +78,24 @@ jobs:
           git lfs install

       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
         with:
           cache-binary: false

       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           lfs: true
           persist-credentials: false

       - name: Login to DockerHub
-        uses: docker/login-action@v3
+        uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_PASSWORD }}

       - name: Build and Push GPU
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
         with:
           context: .
           file: ./docker/lerobot-gpu/Dockerfile
@@ -110,23 +110,23 @@ jobs:
       group: aws-general-8-plus
     steps:
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
         with:
           cache-binary: false

       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false

       - name: Login to DockerHub
-        uses: docker/login-action@v3
+        uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_PASSWORD }}

       - name: Build and Push GPU dev
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
         with:
           context: .
           file: ./docker/lerobot-gpu-dev/Dockerfile
```
.github/workflows/nightly-tests.yml (vendored, 4 changes)

```diff
@@ -33,7 +33,7 @@ jobs:
     runs-on:
       group: aws-general-8-plus
     container:
-      image: huggingface/lerobot-cpu:latest
+      image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images]
       options: --shm-size "16gb"
       credentials:
         username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -60,7 +60,7 @@ jobs:
       CUDA_VISIBLE_DEVICES: "0"
       TEST_TYPE: "single_gpu"
     container:
-      image: huggingface/lerobot-gpu:latest
+      image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images]
       options: --gpus all --shm-size "16gb"
       credentials:
         username: ${{ secrets.DOCKERHUB_USERNAME }}
```
.github/workflows/quality.yml (vendored, 8 changes)

```diff
@@ -33,12 +33,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false

       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
         with:
           python-version: ${{ env.PYTHON_VERSION }}

@@ -64,9 +64,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false

       - name: typos-action
-        uses: crate-ci/typos@v1.29.10
+        uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10
```
.github/workflows/test-docker-build.yml (vendored, 8 changes)

```diff
@@ -35,7 +35,7 @@ jobs:
       matrix: ${{ steps.set-matrix.outputs.matrix }}
     steps:
       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false

@@ -64,17 +64,17 @@ jobs:
         docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
     steps:
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
         with:
           cache-binary: false

       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false

       - name: Build Docker image
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
         with:
           file: ${{ matrix.docker-file }}
           context: .
```
.github/workflows/test.yml (vendored, 12 changes)

```diff
@@ -50,7 +50,7 @@ jobs:
     env:
       MUJOCO_GL: egl
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          lfs: true # Ensure LFS files are pulled
          persist-credentials: false
@@ -62,7 +62,7 @@ jobs:
           sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev

       - name: Install uv and python
-        uses: astral-sh/setup-uv@v5
+        uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
         with:
           enable-cache: true
           version: ${{ env.UV_VERSION }}
@@ -85,7 +85,7 @@ jobs:
     env:
       MUJOCO_GL: egl
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          lfs: true # Ensure LFS files are pulled
          persist-credentials: false
@@ -94,7 +94,7 @@ jobs:
         run: sudo apt-get update && sudo apt-get install -y ffmpeg

       - name: Install uv and python
-        uses: astral-sh/setup-uv@v5
+        uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
         with:
           enable-cache: true
           version: ${{ env.UV_VERSION }}
@@ -117,7 +117,7 @@ jobs:
     env:
       MUJOCO_GL: egl
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          lfs: true # Ensure LFS files are pulled
          persist-credentials: false
@@ -129,7 +129,7 @@ jobs:
           sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev

       - name: Install uv and python
-        uses: astral-sh/setup-uv@v5
+        uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
         with:
           enable-cache: true
           version: ${{ env.UV_VERSION }}
```
.github/workflows/trufflehog.yml (vendored, 4 changes)

```diff
@@ -24,12 +24,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Secret Scanning
-        uses: trufflesecurity/trufflehog@main
+        uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
         with:
           extra_args: --only-verified
```
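Every workflow change above applies the same zizmor-driven hardening: third-party actions are pinned to a full commit SHA with the human-readable tag preserved as a trailing comment, and container images that stay tag-based carry an explicit `# zizmor: ignore[unpinned-images]` waiver. As a hedged sketch (not part of the diff), one way such a pin can be looked up locally, assuming `git` is on PATH; `resolve_action_pin` is a hypothetical helper, with inputs taken from the diffs above:

```python
import subprocess

def resolve_action_pin(repo: str, tag: str) -> str:
    """Return a '<repo>@<sha> # <tag>' pin string for a GitHub Actions tag."""
    out = subprocess.run(
        ["git", "ls-remote", f"https://github.com/{repo}", f"refs/tags/{tag}"],
        capture_output=True, text=True, check=True,
    ).stdout.splitlines()
    # For annotated tags, ls-remote may also print a peeled 'refs/tags/<tag>^{}'
    # line holding the commit SHA; prefer the last listed entry when present.
    sha = out[-1].split()[0]
    return f"{repo}@{sha} # {tag}"

# resolve_action_pin("actions/checkout", "v4.2.2")
# -> "actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2"
```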
.pre-commit-config.yaml

```diff
@@ -37,18 +37,18 @@ repos:
       - id: trailing-whitespace

   - repo: https://github.com/adhtruong/mirrors-typos
-    rev: v1.31.1
+    rev: v1.32.0
     hooks:
       - id: typos
         args: [--force-exclude]

   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.1
+    rev: v3.20.0
     hooks:
       - id: pyupgrade

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.5
+    rev: v0.11.12
     hooks:
       - id: ruff
         args: [--fix]
@@ -57,12 +57,12 @@ repos:

   ##### Security #####
   - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.24.3
+    rev: v8.27.0
     hooks:
       - id: gitleaks

   - repo: https://github.com/woodruffw/zizmor-pre-commit
-    rev: v1.5.2
+    rev: v1.9.0
     hooks:
       - id: zizmor

```
@@ -1,60 +0,0 @@
|
|||||||
// fmt: off
|
|
||||||
// flake8: noqa
|
|
||||||
// !/usr/bin/env python
|
|
||||||
|
|
||||||
// Copyright 2024 The HuggingFace Inc. team.
|
|
||||||
// All rights reserved.
|
|
||||||
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
syntax = "proto3";
|
|
||||||
|
|
||||||
package async_inference;
|
|
||||||
|
|
||||||
// AsyncInference: from Robot perspective
|
|
||||||
// Robot send observations to & executes action received from a remote Policy server
|
|
||||||
service AsyncInference {
|
|
||||||
// Robot -> Policy to share observations with a remote inference server
|
|
||||||
// Policy -> Robot to share actions predicted for given observations
|
|
||||||
rpc SendObservations(stream Observation) returns (Empty);
|
|
||||||
rpc StreamActions(Empty) returns (stream Action);
|
|
||||||
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
|
|
||||||
rpc Ready(Empty) returns (Empty);
|
|
||||||
}
|
|
||||||
|
|
||||||
enum TransferState {
|
|
||||||
TRANSFER_UNKNOWN = 0;
|
|
||||||
TRANSFER_BEGIN = 1;
|
|
||||||
TRANSFER_MIDDLE = 2;
|
|
||||||
TRANSFER_END = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Messages
|
|
||||||
message Observation {
|
|
||||||
// sent by Robot, to remote Policy
|
|
||||||
TransferState transfer_state = 1;
|
|
||||||
bytes data = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message Action {
|
|
||||||
// sent by remote Policy, to Robot
|
|
||||||
TransferState transfer_state = 1;
|
|
||||||
bytes data = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message PolicySetup {
|
|
||||||
// sent by Robot to remote server, to init Policy
|
|
||||||
TransferState transfer_state = 1;
|
|
||||||
bytes data = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message Empty {}
|
|
||||||
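The three commits also delete the first-pass async-inference prototype, starting with the gRPC contract above. For orientation, a minimal hedged sketch of how a client would drive this service with the stubs generated from the proto; the address is an assumption (the deleted server binds port 8080), and the call pattern mirrors the RobotClient deleted further down:

```python
import grpc

import async_inference_pb2
import async_inference_pb2_grpc

channel = grpc.insecure_channel("localhost:8080")  # assumed address
stub = async_inference_pb2_grpc.AsyncInferenceStub(channel)

stub.Ready(async_inference_pb2.Empty())  # handshake: flushes server-side state

# SendObservations is client-streaming; StreamActions returns a server stream.
obs = async_inference_pb2.Observation(
    transfer_state=async_inference_pb2.TRANSFER_BEGIN,
    data=b"...",  # the real client sends a pickled TimedObservation here
)
stub.SendObservations(iter([obs]))

for action in stub.StreamActions(async_inference_pb2.Empty()):
    print(action.transfer_state, len(action.data))  # data is a pickled list[TimedAction]
```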
Deleted: async_inference_pb2.py (@@ -1,48 +0,0 @@)

```python
# fmt: off
# flake8: noqa
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    29,
    0,
    '',
    'async_inference.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x0bPolicySetup\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xa9\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_TRANSFERSTATE']._serialized_start=301
  _globals['_TRANSFERSTATE']._serialized_end=397
  _globals['_OBSERVATION']._serialized_start=42
  _globals['_OBSERVATION']._serialized_end=125
  _globals['_ACTION']._serialized_start=127
  _globals['_ACTION']._serialized_end=205
  _globals['_POLICYSETUP']._serialized_start=207
  _globals['_POLICYSETUP']._serialized_end=290
  _globals['_EMPTY']._serialized_start=292
  _globals['_EMPTY']._serialized_end=299
  _globals['_ASYNCINFERENCE']._serialized_start=400
  _globals['_ASYNCINFERENCE']._serialized_end=697
# @@protoc_insertion_point(module_scope)
```
Deleted: async_inference_pb2_grpc.py (@@ -1,236 +0,0 @@)

```python
# fmt: off
# flake8: noqa
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

import async_inference_pb2 as async__inference__pb2

GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in async_inference_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )


class AsyncInferenceStub:
    """AsyncInference: from Robot perspective
    Robot send observations to & executes action received from a remote Policy server
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.SendObservations = channel.stream_unary(
                '/async_inference.AsyncInference/SendObservations',
                request_serializer=async__inference__pb2.Observation.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        self.StreamActions = channel.unary_stream(
                '/async_inference.AsyncInference/StreamActions',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Action.FromString,
                _registered_method=True)
        self.SendPolicyInstructions = channel.unary_unary(
                '/async_inference.AsyncInference/SendPolicyInstructions',
                request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        self.Ready = channel.unary_unary(
                '/async_inference.AsyncInference/Ready',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)


class AsyncInferenceServicer:
    """AsyncInference: from Robot perspective
    Robot send observations to & executes action received from a remote Policy server
    """

    def SendObservations(self, request_iterator, context):
        """Robot -> Policy to share observations with a remote inference server
        Policy -> Robot to share actions predicted for given observations
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def StreamActions(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendPolicyInstructions(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Ready(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_AsyncInferenceServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'SendObservations': grpc.stream_unary_rpc_method_handler(
                    servicer.SendObservations,
                    request_deserializer=async__inference__pb2.Observation.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'StreamActions': grpc.unary_stream_rpc_method_handler(
                    servicer.StreamActions,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Action.SerializeToString,
            ),
            'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
                    servicer.SendPolicyInstructions,
                    request_deserializer=async__inference__pb2.PolicySetup.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'Ready': grpc.unary_unary_rpc_method_handler(
                    servicer.Ready,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'async_inference.AsyncInference', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class AsyncInference:
    """AsyncInference: from Robot perspective
    Robot send observations to & executes action received from a remote Policy server
    """

    @staticmethod
    def SendObservations(request_iterator,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.stream_unary(
            request_iterator,
            target,
            '/async_inference.AsyncInference/SendObservations',
            async__inference__pb2.Observation.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def StreamActions(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_stream(
            request,
            target,
            '/async_inference.AsyncInference/StreamActions',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Action.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendPolicyInstructions(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/SendPolicyInstructions',
            async__inference__pb2.PolicySetup.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Ready(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/Ready',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
```
Deleted: lerobot/scripts/server/constants.py (@@ -1,12 +0,0 @@)

```python
"""Server/Client side: Sometimes you just want the environment to wait a tiny bit"""
idle_wait = 0.01

"""Client side: The environment evolves with a time resolution equal to environment_dt"""
environment_dt = 1 / 30

"""Server side: Running inference on (at most) environment_dt"""
inference_latency = environment_dt

"""Supported policies"""
supported_policies = ["act", "smolvla"]
```
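These constants tie the server's pacing to the client's control rate: with `inference_latency = environment_dt`, the server sleeps off whatever remains of the ~33 ms budget after producing a chunk (see `_predict_action_chunk` below). A minimal sketch of that pacing rule, where `predict` is a hypothetical stand-in for the policy call:

```python
import time

environment_dt = 1 / 30           # client control period, ~33 ms
inference_latency = environment_dt

def paced_chunk(predict):
    """Run predict() and sleep out the rest of the latency budget, as PolicyServer does."""
    start = time.time()
    chunk = predict()
    elapsed = time.time() - start
    time.sleep(max(0, inference_latency - max(0, elapsed)))
    return chunk
```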
Deleted: lerobot/scripts/server/helpers.py (@@ -1,128 +0,0 @@)

```python
import logging
import logging.handlers
import os
import time
from typing import Any

import torch


def setup_logging(prefix: str, info_bracket: str):
    """Sets up logging"""
    # Create logs directory if it doesn't exist
    os.makedirs("logs", exist_ok=True)

    # Delete any existing prefix_* log files
    for old_log_file in os.listdir("logs"):
        if old_log_file.startswith(prefix) and old_log_file.endswith(".log"):
            try:
                os.remove(os.path.join("logs", old_log_file))
                print(f"Deleted old log file: {old_log_file}")
            except Exception as e:
                print(f"Failed to delete old log file {old_log_file}: {e}")

    # Set up logging with both console and file output
    logger = logging.getLogger(prefix)
    # Prevent propagation to root logger to avoid duplicate messages
    logger.propagate = False

    logger.setLevel(logging.INFO)

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(
        logging.Formatter(
            f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(console_handler)

    # File handler - creates a new log file for each run
    file_handler = logging.handlers.RotatingFileHandler(
        f"logs/policy_server_{int(time.time())}.log",
        maxBytes=10 * 1024 * 1024,  # 10MB
        backupCount=5,
    )
    file_handler.setFormatter(
        logging.Formatter(
            f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(file_handler)

    return logger


class TimedData:
    def __init__(self, timestamp: float, data: Any, timestep: int):
        """Initialize a TimedData object.

        Args:
            timestamp: Unix timestamp relative to data's creation.
            data: The actual data to wrap a timestamp around.
            timestep: The timestep of the data.
        """
        self.timestamp = timestamp
        self.data = data
        self.timestep = timestep

    def get_data(self):
        return self.data

    def get_timestamp(self):
        return self.timestamp

    def get_timestep(self):
        return self.timestep


class TimedAction(TimedData):
    def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
        super().__init__(timestamp=timestamp, data=action, timestep=timestep)

    def get_action(self):
        return self.get_data()


class TimedObservation(TimedData):
    def __init__(
        self,
        timestamp: float,
        observation: dict[str, torch.Tensor],
        timestep: int,
        transfer_state: int = 0,
        must_go: bool = False,
    ):
        super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
        self.transfer_state = transfer_state
        self.must_go = must_go

    def get_observation(self):
        return self.get_data()


class TinyPolicyConfig:
    def __init__(
        self,
        policy_type: str = "act",
        pretrained_name_or_path: str = "fracapuano/act_so100_test",
        device: str = "cpu",
    ):
        self.policy_type = policy_type
        self.pretrained_name_or_path = pretrained_name_or_path
        self.device = device


def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
    """Check if two observation states are similar, under a tolerance threshold"""
    return torch.linalg.norm(obs1_state - obs2_state) < atol


def observations_similar(obs1: TimedObservation, obs2: TimedObservation, atol: float = 1) -> bool:
    """Check if two observations are similar, under a tolerance threshold"""
    obs1_state = obs1.get_observation()["observation.state"]
    obs2_state = obs2.get_observation()["observation.state"]

    return _compare_observation_states(obs1_state, obs2_state, atol=atol)
```
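`observations_similar` reduces to an L2-norm test on the `observation.state` entries, so two states within a unit ball of each other (at the default `atol=1`) are treated as duplicates. A small sketch using the classes above, with made-up state values:

```python
import torch

s1 = torch.zeros(6)
s2 = torch.full((6,), 0.1)

obs1 = TimedObservation(timestamp=0.0, observation={"observation.state": s1}, timestep=0)
obs2 = TimedObservation(timestamp=1 / 30, observation={"observation.state": s2}, timestep=1)

# ||s1 - s2||_2 = 0.1 * sqrt(6) ≈ 0.245 < atol=1, so the pair counts as similar
assert observations_similar(obs1, obs2, atol=1)
```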
Deleted: lerobot/scripts/server/policy_server.py (@@ -1,429 +0,0 @@)

```python
import itertools
import pickle  # nosec
import time
from concurrent import futures
from queue import Queue
from typing import Generator, List, Optional

import async_inference_pb2  # type: ignore
import async_inference_pb2_grpc  # type: ignore
import grpc
import torch
from datasets import load_dataset

from lerobot.common.policies.factory import get_policy_class
from lerobot.scripts.server.constants import environment_dt, idle_wait, inference_latency, supported_policies
from lerobot.scripts.server.helpers import (
    TimedAction,
    TimedObservation,
    TinyPolicyConfig,
    observations_similar,
    setup_logging,
)


class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
    prefix = "policy_server"
    info_bracket = "SERVER"
    logger = setup_logging(prefix, info_bracket)

    def __init__(self):
        # Initialize dataset action generator (to debug this first version, will be removed in the future)
        self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())

        self._setup_server()

        self.actions_per_chunk = 20
        self.actions_overlap = 10

        self.running = True

    def _setup_server(self) -> None:
        """Flushes server state when new client connects."""
        # only running inference on the latest observation received by the server
        self.observation_queue = Queue(maxsize=1)
        self._predicted_timesteps = set()
        self._predicted_observations = Queue(maxsize=1)

    def Ready(self, request, context):  # noqa: N802
        client_id = context.peer()
        self.logger.info(f"Client {client_id} connected and ready")
        self._setup_server()

        return async_inference_pb2.Empty()

    def SendPolicyInstructions(self, request, context):  # noqa: N802
        """Receive policy instructions from the robot client"""
        client_id = context.peer()
        self.logger.debug(f"Receiving policy instructions from {client_id}")

        policy_specs = pickle.loads(request.data)  # nosec
        assert isinstance(policy_specs, TinyPolicyConfig), (
            f"Policy specs must be a TinyPolicyConfig. Got {type(policy_specs)}"
        )

        self.logger.info(
            f"Policy type: {policy_specs.policy_type} | "
            f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
            f"Device: {policy_specs.device}"
        )

        assert policy_specs.policy_type in supported_policies, (
            f"Policy type {policy_specs.policy_type} not supported. Supported policies: {supported_policies}"
        )

        self.device = policy_specs.device
        self.policy_type = policy_specs.policy_type  # act, pi0, etc.

        policy_class = get_policy_class(self.policy_type)

        start = time.time()
        self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
        self.policy.to(self.device)
        end = time.time()

        self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")

        return async_inference_pb2.Empty()

    def SendObservations(self, request_iterator, context):  # noqa: N802
        """Receive observations from the robot client"""
        client_id = context.peer()
        self.logger.debug(f"Receiving observations from {client_id}")

        for observation in request_iterator:
            receive_time = time.time()
            timed_observation = pickle.loads(observation.data)  # nosec
            deserialize_time = time.time()

            self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")

            if not self._maybe_enqueue_observation(timed_observation):
                continue

            queue_time = time.time()

            obs_timestep = timed_observation.get_timestep()
            obs_timestamp = timed_observation.get_timestamp()

            self.logger.info(
                f"Received observation #{obs_timestep} | "
                f"Client timestamp: {obs_timestamp:.6f} | "
                f"Server timestamp: {receive_time:.6f} | "
            )

            if not hasattr(self, "previous_obs_timestamp"):
                self.previous_obs_timestamp = obs_timestamp

            self.logger.debug(
                f"1/DeltaObsT (~frequency): {1 / (1e-6 + obs_timestamp - self.previous_obs_timestamp):.6f} Hz| "
                f"Network latency: {receive_time - obs_timestamp:.6f}s | "
                f"Deserialization time: {deserialize_time - receive_time:.6f}s | "
                f"Queue time: {queue_time - deserialize_time:.6f}s | "
            )

            self.previous_obs_timestamp = obs_timestamp

        return async_inference_pb2.Empty()

    def StreamActions(self, request, context):  # noqa: N802
        """Stream actions to the robot client"""
        client_id = context.peer()
        self.logger.debug(f"Client {client_id} connected for action streaming")

        # Generate action based on the most recent observation and its timestep
        try:
            obs = self.observation_queue.get()
            self.logger.info(
                f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
            )

            if obs:
                self.last_predicted_obs = obs
                self._predicted_timesteps.add(obs.get_timestep())
                start_time = time.time()
                action_chunk = self._predict_action_chunk(obs)
                # action_chunk = self._read_action_chunk(obs)
                inference_time = time.time() - start_time

                start_time = time.time()
                action_bytes = pickle.dumps(action_chunk)  # nosec
                serialize_time = time.time() - start_time

                # Create and return the Action
                action = async_inference_pb2.Action(transfer_state=obs.transfer_state, data=action_bytes)

                self.logger.info(
                    f"Action chunk #{obs.get_timestep()} generated | Inference time: {inference_time:.6f}s |"
                )

                self.logger.debug(
                    f"Action chunk #{obs.get_timestep()} generated | "
                    f"Inference time: {inference_time:.6f}s |"
                    f"Serialize time: {serialize_time:.6f}s |"
                    f"Total time: {inference_time + serialize_time:.6f}s"
                )

                yield action
            else:
                self.logger.warning("No observation in queue yet!")
                time.sleep(idle_wait)

        except Exception as e:
            self.logger.error(f"Error in StreamActions: {e}")

        return async_inference_pb2.Empty()

    def _enqueue_and_go(self, obs: TimedObservation):
        # If queue is full, get the old observation to make room
        if self.observation_queue.full():
            # pops from queue
            _ = self.observation_queue.get_nowait()
            self.logger.debug("Observation queue was full, removed oldest observation")

        # Now put the new observation (never blocks as queue is non-full here)
        self.observation_queue.put(obs)
        return True

    def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
        if obs.get_timestep() in self._predicted_timesteps:
            self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
            return False

        elif observations_similar(obs, previous_obs, atol=1):
            self.logger.debug(
                f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
            )
            return False

        else:
            return True

    def _maybe_enqueue_observation(self, obs: TimedObservation) -> bool:
        """Enqueue an observation if it must go through processing, otherwise skip it.
        Observations not in queue are never run through the policy network"""

        if obs.must_go or not hasattr(self, "last_predicted_obs"):
            self.logger.info(f"[MUST GO] Enqueued observation #{obs.get_timestep()} for direct processing!")
            return self._enqueue_and_go(obs)

        else:
            if self._obs_sanity_checks(obs, self.last_predicted_obs):
                return self._enqueue_and_go(obs)
            else:
                return False

    def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
        """Turn a chunk of actions into a list of TimedAction instances,
        with the first action corresponding to t_0 and the rest corresponding to
        t_0 + i*environment_dt for i in range(len(action_chunk))
        """
        return [
            TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
        ]

    @torch.no_grad()
    def _run_act_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
        """Run ACT-like policies"""
        start_time = time.time()

        # prepare observation for policy forward pass
        batch = self.policy.normalize_inputs(observation)
        normalize_time = time.time()
        self.logger.debug(f"Observation normalization time: {normalize_time - start_time:.6f}s")

        if self.policy.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = [batch[key] for key in self.policy.config.image_features]
            prep_time = time.time()
            self.logger.debug(f"Observation image preparation time: {prep_time - normalize_time:.6f}s")

        # forward pass outputs up to policy.config.n_action_steps != actions_per_chunk
        actions = self.policy.model(batch)[0][:, : self.actions_per_chunk]

        actions = self.policy.unnormalize_outputs({"action": actions})["action"]

        end_time = time.time()
        self.logger.info(f"[ACT] Action chunk generation total time: {end_time - start_time:.6f}s")

        return actions

    @torch.no_grad()
    def _run_pi0_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
        """Run PI0-like policies"""
        raise NotImplementedError("PI0 policy not implemented yet")

    @torch.no_grad()
    def _run_smolvla_policy(
        self, observation: dict[str, torch.Tensor], noise: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run smolvla-like policies"""
        observation = self.policy.normalize_inputs(observation)

        images, img_masks = self.policy.prepare_images(observation)
        state = self.policy.prepare_state(observation)
        lang_tokens, lang_masks = self.policy.prepare_language(observation)

        actions = self.policy.model.sample_actions(
            images, img_masks, lang_tokens, lang_masks, state, noise=noise
        )

        # Unpad actions
        original_action_dim = self.policy.config.action_feature.shape[0]
        actions = actions[:, :, :original_action_dim]

        actions = self.policy.unnormalize_outputs(
            {"action": actions, "robot_type": [self.policy.config.robot_type]}
        )["action"]

        return actions

    def _get_action_chunk(
        self, observation: dict[str, torch.Tensor], policy_type: str = "act"
    ) -> torch.Tensor:
        """Get an action chunk from the policy"""
        if policy_type == "act":
            return self._run_act_policy(observation)
        elif policy_type == "smolvla":
            return self._run_smolvla_policy(observation)
        else:
            raise ValueError(f"Policy class {policy_type} not supported")

    def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
        """Predict an action based on the observation"""
        """1. Prepare observation"""
        start_time = time.time()

        observation = {
            "robot_type": [self.policy.config.robot_type],
        }
        for k, v in observation_t.get_observation().items():
            if isinstance(v, torch.Tensor):  # VLAs present natural-language instructions
                if "image" in k:
                    # Add batch dimension first, then reorder to NCHW format, then normalize to [0, 1]
                    observation[k] = (
                        v.unsqueeze(0).permute(0, 3, 1, 2).to(self.device, non_blocking=True) / 255.0
                    )
                else:
                    observation[k] = v.unsqueeze(0).to(self.device, non_blocking=True)
            else:
                observation[k] = v  # textual instructions are passed as a list of strings

        prep_time = time.time()
        self.logger.debug(f"Observation preparation time: {prep_time - start_time:.6f}s")

        """2. Get action chunk"""
        action_tensor = self._get_action_chunk(observation, self.policy_type)
        action_tensor = action_tensor.squeeze(0)

        # Move to CPU before serializing
        action_tensor = action_tensor.cpu()

        post_inference_time = time.time()
        self.logger.debug(f"Post-inference processing start: {post_inference_time - prep_time:.6f}s")

        if action_tensor.dim() == 1:
            # No chunk dimension, so repeat action to create a (dummy) chunk of actions
            action_tensor = action_tensor.repeat(self.actions_per_chunk, 1)

        action_chunk = self._time_action_chunk(
            observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
        )

        chunk_time = time.time()
        self.logger.debug(f"Action chunk creation time: {chunk_time - post_inference_time:.6f}s")
        time.sleep(
            max(0, inference_latency - max(0, chunk_time - start_time))
        )  # sleep to control inference latency

        return action_chunk

    def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
        """Stream chunks of actions from a prerecorded dataset.

        Returns:
            Generator that yields chunks of actions from the dataset
        """
        import warnings

        warnings.warn(
            "This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
        )

        dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")

        # 1. Select the action column only, where you will find tensors with 6 elements
        actions = dataset["action"]
        action_indices = torch.arange(len(actions))

        # 2. Chunk the iterable of tensors into chunks with 10 elements each
        # sending only first element for debugging
        indices_chunks = action_indices.unfold(
            0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
        )

        for idx_chunk in indices_chunks:
            yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]

    def _read_action_chunk(self, observation: Optional[TimedObservation] = None) -> list[TimedAction]:
        """Dummy function for predicting action chunk given observation.

        Instead of computing actions on-the-fly, this method streams
        actions from a prerecorded dataset.
        """
        import warnings

        warnings.warn(
            "This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
        )

        start_time = time.time()
        if not observation:
            observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)

        # Get chunk of actions from the generator
        actions_chunk = next(self.action_generator)

        # Return a list of TimedActions, with timestamps starting from the observation timestamp
        actions_chunk = self._time_action_chunk(
            observation.get_timestamp(), actions_chunk, observation.get_timestep()
        )

        chunk_time = time.time()
        self.logger.debug(f"Action chunk creation time: {chunk_time - start_time:.6f}s")

        # slow action generation, emulates inference time
        time.sleep(max(0, inference_latency - max(0, chunk_time - start_time)))

        return actions_chunk

    def stop(self):
        """Stop the server"""
        self.running = False
        self.logger.info("Server stopping...")


def serve():
    port = 8080
    # Create the server instance first
    policy_server = PolicyServer()

    # Setup and start gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    policy_server.logger.info(f"PolicyServer started on port {port}")

    try:
        # Use the running attribute to control server lifetime
        while policy_server.running:
            time.sleep(1)  # Check every second instead of sleeping indefinitely

    except KeyboardInterrupt:
        policy_server.stop()
        policy_server.logger.info("Keyboard interrupt received")


if __name__ == "__main__":
    serve()
```
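One detail worth noting in the deleted server: `_time_action_chunk` stamps the i-th action of a chunk at `t_0 + i * environment_dt`, anchored to the source observation's timestamp, so wall-clock targets and timesteps advance together. A worked example of that arithmetic (the observation timestamp and timestep are made up):

```python
# environment_dt = 1 / 30, as in constants.py above
t_0, i_0 = 2.0, 100  # hypothetical observation timestamp and timestep
stamps = [(round(t_0 + i * (1 / 30), 4), i_0 + i) for i in range(3)]
# -> [(2.0, 100), (2.0333, 101), (2.0667, 102)]
```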
@@ -1,608 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import pickle # nosec
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from queue import Empty, Queue
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import async_inference_pb2 # type: ignore
|
|
||||||
import async_inference_pb2_grpc # type: ignore
|
|
||||||
import grpc
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
|
||||||
from lerobot.scripts.server.constants import environment_dt, idle_wait
|
|
||||||
from lerobot.scripts.server.helpers import TimedAction, TimedObservation, TinyPolicyConfig, setup_logging
|
|
||||||
|
|
||||||
|
|
||||||
class RobotClient:
|
|
||||||
prefix = "robot_client"
|
|
||||||
info_bracket = "CLIENT"
|
|
||||||
logger = setup_logging(prefix, info_bracket)
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
server_address: Optional[str] = None,
|
|
||||||
policy_type: str = "smolvla",
|
|
||||||
pretrained_name_or_path: str = "lerobot/smolvla_base",
|
|
||||||
policy_device: str = "cuda",
|
|
||||||
chunk_size_threshold: float = 0.5,
|
|
||||||
robot: str = "so100",
|
|
||||||
):
|
|
||||||
# Use environment variable if server_address is not provided
|
|
||||||
if server_address is None:
|
|
||||||
server_address = os.getenv("SERVER_ADDRESS", "localhost:8080")
|
|
||||||
self.logger.info(f"No server address provided, using default address: {server_address}")
|
|
||||||
|
|
||||||
self.policy_config = TinyPolicyConfig(policy_type, pretrained_name_or_path, policy_device)
|
|
||||||
self.channel = grpc.insecure_channel(server_address)
|
|
||||||
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
|
||||||
self.logger.info(f"Initializing client to connect to server at {server_address}")
|
|
||||||
|
|
||||||
self.running = False
|
|
||||||
self.must_go = True # does the observation qualify for direct processing on the policy server?
|
|
||||||
|
|
||||||
self.latest_action = -1
|
|
||||||
self.action_chunk_size = -1
|
|
||||||
|
|
||||||
self._chunk_size_threshold = chunk_size_threshold
|
|
||||||
|
|
||||||
self.action_queue = Queue()
|
|
||||||
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
self.robot = make_robot(robot)
|
|
||||||
self.robot.connect()
|
|
||||||
|
|
||||||
connect_time = time.time()
|
|
||||||
self.logger.info(f"Robot connection time: {connect_time - start_time:.4f}s")
|
|
||||||
|
|
||||||
time.sleep(idle_wait) # sleep waiting for cameras to activate
|
|
||||||
self.logger.info("Robot connected and ready")
|
|
||||||
|
|
||||||
def timestamps(self):
|
|
||||||
"""Get the timestamps of the actions in the queue"""
|
|
||||||
return sorted([action.get_timestep() for action in self.action_queue.queue])
|
|
||||||
|
|
||||||
    def start(self):
        """Start the robot client and connect to the policy server"""
        try:
            # client-server handshake
            start_time = time.time()
            self.stub.Ready(async_inference_pb2.Empty())
            end_time = time.time()
            self.logger.info(f"Connected to policy server in {end_time - start_time:.4f}s")

            # send policy instructions
            policy_config_bytes = pickle.dumps(self.policy_config)
            policy_setup = async_inference_pb2.PolicySetup(
                transfer_state=async_inference_pb2.TRANSFER_BEGIN, data=policy_config_bytes
            )

            self.logger.info("Sending policy instructions to policy server")
            self.logger.info(
                f"Policy type: {self.policy_config.policy_type} | "
                f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
                f"Device: {self.policy_config.device}"
            )

            self.stub.SendPolicyInstructions(policy_setup)

            self.running = True
            self.available_actions_size = []
            return True

        except grpc.RpcError as e:
            self.logger.error(f"Failed to connect to policy server: {e}")
            return False

    def stop(self):
        """Stop the robot client"""
        self.running = False

        self.robot.disconnect()
        self.logger.info("Robot disconnected")

        self.channel.close()
        self.logger.info("Client stopped, channel closed")

    def send_observation(
        self,
        obs: TimedObservation,
        transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
    ) -> bool:
        """Send observation to the policy server.
        Returns True if the observation was sent successfully, False otherwise."""
        if not self.running:
            self.logger.warning("Client not running")
            return False

        assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"

        start_time = time.time()
        observation_bytes = pickle.dumps(obs)
        serialize_time = time.time()
        self.logger.debug(f"Observation serialization time: {serialize_time - start_time:.6f}s")

        observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)

        try:
            send_start = time.time()
            _ = self.stub.SendObservations(iter([observation]))
            send_end = time.time()

            obs_timestep = obs.get_timestep()

            self.logger.info(
                f"Sent observation #{obs_timestep} | "
                f"Serialize time: {serialize_time - start_time:.6f}s | "
                f"Network time: {send_end - send_start:.6f}s | "
                f"Total time: {send_end - start_time:.6f}s"
            )

            self.last_obs_sent_time = send_end
            return True

        except grpc.RpcError as e:
            self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
            return False

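    # Wire-format sketch for the method above: the TimedObservation is pickled
    # and wrapped in the generated Observation message before the client-streaming
    # call (`timed_obs` is an illustrative placeholder):
    #
    #     payload = async_inference_pb2.Observation(
    #         transfer_state=async_inference_pb2.TRANSFER_MIDDLE,
    #         data=pickle.dumps(timed_obs),
    #     )
    #     stub.SendObservations(iter([payload]))
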
    def _validate_action(self, action: TimedAction):
        """Received actions are kept only when they have been produced for now or later, never before"""
        return action.get_timestep() > self.latest_action

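    # For instance, with latest_action == 10 an incoming action for timestep 10
    # (or earlier) is stale and dropped, while timestep 11 is kept
    # (`a` below is an illustrative action tensor):
    #
    #     client.latest_action = 10
    #     client._validate_action(TimedAction(timestamp=time.time(), action=a, timestep=10))  # -> False
    #     client._validate_action(TimedAction(timestamp=time.time(), action=a, timestep=11))  # -> True
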
    def _inspect_action_queue(self):
        queue_size = self.action_queue.qsize()
        timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
        self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")

        return queue_size, timestamps

    def _update_action_queue(self, actions: list[TimedAction]):
        """Update the action queue with new actions, swapping in a freshly built
        queue atomically so the consumer never observes a transiently empty queue"""
        new_queue = Queue()
        for action in actions:
            if self._validate_action(action):
                new_queue.put(action)

        self.action_queue = new_queue

    def _aggregate_action_queues(
        self,
        incoming_actions: list[TimedAction],
        aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Find actions with matching timesteps in the queue and aggregate them using the aggregate_fn"""
        # TODO(fracapuano): move outside of the function and make aggregate_fn an always required argument
        if not aggregate_fn:
            # default aggregate function: take the latest action
            def aggregate_fn(x1, x2):
                return x2

        action_intersections: list[TimedAction] = []
        current_action_queue = {
            action.get_timestep(): action.get_action() for action in self.action_queue.queue
        }

        for new_action in incoming_actions:
            if new_action.get_timestep() in current_action_queue:
                # TODO(fracapuano): There is probably a way to do this with broadcasting of the two action tensors
                action_intersections.append(
                    TimedAction(
                        timestamp=new_action.get_timestamp(),
                        action=aggregate_fn(
                            current_action_queue[new_action.get_timestep()], new_action.get_action()
                        ),
                        timestep=new_action.get_timestep(),
                    )
                )
            else:
                action_intersections.append(new_action)

        new_queue = Queue()
        for action in action_intersections:
            if self._validate_action(action):
                new_queue.put(action)

        self.action_queue = new_queue

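    # A custom aggregate_fn can blend overlapping chunks instead of keeping only
    # the newest prediction. A minimal sketch, assuming actions are torch tensors
    # (`weighted_blend` is illustrative, not part of the client):
    #
    #     def weighted_blend(old_action, new_action, alpha=0.3):
    #         # keep a fraction of the old chunk to smooth the hand-off
    #         return alpha * old_action + (1 - alpha) * new_action
    #
    #     client._aggregate_action_queues(incoming_actions, aggregate_fn=weighted_blend)
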
    def _clear_action_queue(self):
        """Clear the existing queue"""
        while not self.action_queue.empty():
            try:
                self.action_queue.get_nowait()
            except Empty:
                break

    def _fill_action_queue(self, actions: list[TimedAction]):
        """Fill the action queue with incoming valid actions"""
        start_time = time.time()
        valid_count = 0

        for action in actions:
            if self._validate_action(action):
                self.action_queue.put(action)
                valid_count += 1

        end_time = time.time()
        self.logger.debug(
            f"Queue filled: {valid_count}/{len(actions)} valid actions added in {end_time - start_time:.6f}s"
        )

    def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
        """Empty the queue, then refill it with the valid incoming actions"""
        self._clear_action_queue()
        self._fill_action_queue(actions)

    def receive_actions(self):
        """Receive actions from the policy server"""
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        self.logger.info("Action receiving thread starting")

        while self.running:
            try:
                # Use StreamActions to get a stream of actions from the server
                for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
                    receive_time = time.time()

                    # Deserialize bytes back into list[TimedAction]
                    deserialize_start = time.time()
                    timed_actions = pickle.loads(actions_chunk.data)  # nosec
                    deserialize_end = time.time()

                    self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))

                    self.logger.info(f"Current latest action: {self.latest_action}")

                    # Get queue state before changes
                    old_size, old_timesteps = self._inspect_action_queue()
                    if not old_timesteps:
                        old_timesteps = [self.latest_action]  # queue was empty

                    # Log incoming actions
                    incoming_timesteps = [a.get_timestep() for a in timed_actions]
                    if not incoming_timesteps:
                        incoming_timesteps = [self.latest_action]  # empty chunk received

                    # Calculate network latency if we have matching observations
                    if len(timed_actions) > 0:
                        first_action_timestep = timed_actions[0].get_timestep()
                        server_to_client_latency = receive_time - self.last_obs_sent_time

                        self.logger.info(
                            f"Received action chunk for step #{first_action_timestep} | "
                            f"Latest action: #{self.latest_action} | "
                            f"Network latency (server->client): {server_to_client_latency:.6f}s | "
                            f"Deserialization time: {deserialize_end - deserialize_start:.6f}s"
                        )

                    # Update action queue
                    start_time = time.time()
                    self._update_action_queue(timed_actions)
                    queue_update_time = time.time() - start_time

                    # after receiving actions, the next empty queue triggers must-go processing!
                    self.must_go = True

                    # Get queue state after changes
                    new_size, new_timesteps = self._inspect_action_queue()
                    if not new_timesteps:
                        new_timesteps = [self.latest_action]  # nothing valid was queued

                    self.logger.info(
                        f"Queue update complete ({queue_update_time:.6f}s) | "
                        f"Before: {old_size} items | "
                        f"After: {new_size} items"
                    )
                    self.logger.info(
                        f"Latest action: {self.latest_action} | "
                        f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
                        f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
                        f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
                    )

            except grpc.RpcError as e:
                self.logger.error(f"Error receiving actions: {e}")
                # Avoid tight loop on action receiver error
                time.sleep(idle_wait)

    def _actions_available(self):
        """Check if there are actions available in the queue"""
        return not self.action_queue.empty()

    def _get_next_action(self) -> Optional[TimedAction]:
        """Get the next action from the queue"""
        try:
            return self.action_queue.get_nowait()
        except Empty:
            return None

    def _perform_action(self, timed_action: TimedAction):
        """Send the action to the robot and record it as the latest executed timestep"""
        self.robot.send_action(timed_action.get_action())
        self.latest_action = timed_action.get_timestep()

        self.logger.debug(
            f"Ts={timed_action.get_timestamp()} | "
            f"Action #{timed_action.get_timestep()} performed | "
            f"Queue size: {self.action_queue.qsize()}"
        )

    def execute_actions(self):
        """Continuously execute actions from the queue"""
        import warnings

        warnings.warn("This method is deprecated! Will be removed soon!", DeprecationWarning, stacklevel=2)
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        time.sleep(idle_wait)  # wait for observation capture to start

        self.logger.info("Action execution thread starting")

        while self.running:
            # constantly monitor the size of the action queue
            self.available_actions_size.append(self.action_queue.qsize())

            if self._actions_available():
                timed_action = self._get_next_action()
                self._perform_action(timed_action)

                time.sleep(environment_dt)

            else:
                self.logger.debug("No action available | Sleeping")
                time.sleep(idle_wait)

    def stream_observations(self, get_observation_fn):
        """Continuously stream observations to the server"""
        import warnings

        warnings.warn("This method is deprecated! Will be removed soon!", DeprecationWarning, stacklevel=2)

        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        self.logger.info("Observation streaming thread starting")

        while self.running:
            try:
                # Get serialized observation bytes from the function
                start_time = time.time()
                observation = get_observation_fn()
                obs_capture_time = time.time() - start_time

                self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")

                if not hasattr(self, "last_obs_timestamp"):
                    self.last_obs_timestamp = observation.get_timestamp()

                obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()
                self.logger.info(
                    f"Ts={obs_timestamp} | "
                    f"Captured observation #{obs_timestep} | "
                    f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
                )

                self.last_obs_timestamp = obs_timestamp

                # Set appropriate transfer state
                if obs_timestep == 0:
                    state = async_inference_pb2.TRANSFER_BEGIN
                else:
                    state = async_inference_pb2.TRANSFER_MIDDLE

                time.sleep(environment_dt)
                self.send_observation(observation, state)

            except Exception as e:
                self.logger.error(f"Error in observation sender: {e}")
                time.sleep(idle_wait)

    def control_loop_action(self):
        """Reading and performing actions in local queue"""
        self.available_actions_size.append(self.action_queue.qsize())
        if self._actions_available():
            # Get action from queue
            get_start = time.time()
            timed_action = self._get_next_action()
            get_time = time.time() - get_start

            self.logger.debug(
                f"Popping action from queue to perform took {get_time:.6f}s | "
                f"Queue size: {self.action_queue.qsize()}"
            )

            self._perform_action(timed_action)

    def _ready_to_send_observation(self):
        """Flags when the client is ready to send an observation"""
        return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold

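    # For example, with action_chunk_size == 20 and chunk_size_threshold == 0.5
    # (`g` in the paper), a new observation is sent once the queue has drained to
    # half a chunk:
    #
    #     12 / 20 = 0.60 > 0.5   -> keep consuming the current chunk
    #     10 / 20 = 0.50 <= 0.5  -> capture and send a new observation
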
    def control_loop_observation(self, get_observation_fn):
        try:
            # Get serialized observation bytes from the function
            start_time = time.time()
            observation = get_observation_fn()
            obs_capture_time = time.time() - start_time

            # If there are no actions left in the queue, the observation must go through processing!
            observation.must_go = self.must_go and self.action_queue.empty()
            self.logger.debug(f"QUEUE SIZE: {self.action_queue.qsize()} (Must go: {observation.must_go})")
            if observation.must_go:
                # must-go flag will be set again after receiving actions
                self.must_go = False

            if not hasattr(self, "last_obs_timestamp"):
                self.last_obs_timestamp = observation.get_timestamp()

            obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()

            self.logger.info(
                f"Ts={obs_timestamp} | "
                f"Captured observation #{obs_timestep} | "
                f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
            )
            # update only after logging, so the frequency estimate uses the previous timestamp
            self.last_obs_timestamp = obs_timestamp

            self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")

            # Set appropriate transfer state
            if obs_timestep == 0:
                state = async_inference_pb2.TRANSFER_BEGIN
            else:
                state = async_inference_pb2.TRANSFER_MIDDLE

            self.send_observation(observation, state)

        except Exception as e:
            self.logger.error(f"Error in observation sender: {e}")

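    # must-go in one pass (illustrative): if the queue runs dry before a fresh
    # chunk arrives, the very next observation is flagged for direct processing
    # on the policy server:
    #
    #     client.must_go = True          # set at startup and after each received chunk
    #     # queue empty -> observation.must_go == True, client.must_go reset to False
    #     # receive_actions() then sets client.must_go = True again on the next chunk
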
    def control_loop(self, get_observation_fn):
        """Combined function for executing actions and streaming observations"""
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        self.logger.info("Control loop thread starting")

        control_loops = 0
        while self.running:
            control_loop_start = time.time()
            # Control loop: (1) performing actions available in the local queue
            self.control_loop_action()

            # Control loop: (2) streaming observations to the remote policy server
            if self._ready_to_send_observation() or control_loops == 0:
                self.control_loop_observation(get_observation_fn)

            # Dynamically adjust sleep time to maintain the desired control frequency
            time.sleep(max(0, environment_dt - (time.time() - control_loop_start)))
            control_loops += 1


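# Timing note for control_loop above: with environment_dt as the target control
# period, an iteration that takes, say, 10 ms sleeps for the remaining
# (environment_dt - 0.010) s, holding the loop at ~1/environment_dt Hz;
# iterations slower than environment_dt are not throttled further, since the
# sleep clamps at 0 and the loop simply runs late.

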
def async_client(task_instruction: str, verbose: int = 0):
    client = RobotClient()

    if client.start():
        # Function to get observations from the robot
        def get_observation():
            observation_content = client.robot.capture_observation()
            observation_content["task"] = [task_instruction]

            observation = TimedObservation(
                timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
            )

            return observation

        client.logger.info("Starting all threads...")

        # Create and start action receiver thread
        action_receiver_thread = threading.Thread(target=client.receive_actions)
        action_receiver_thread.daemon = True

        control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
        control_loop_thread.daemon = True

        # Start all threads
        action_receiver_thread.start()
        control_loop_thread.start()

        try:
            while client.running:
                time.sleep(idle_wait)

        except KeyboardInterrupt:
            pass

        finally:
            client.stop()
            client.logger.info("Client stopped")


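# Programmatic entry point sketch (the module path is assumed for illustration):
#
#     from lerobot.scripts.server.robot_client import async_client
#     async_client("fold my tshirt")
#
# The CLI below exposes the same flow with configurable flags.

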
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Robot client for executing tasks via policy server")
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="Task instruction for the robot to execute (e.g., 'fold my tshirt')",
    )
    parser.add_argument("--verbose", type=int, default=0, help="Verbosity level (default: 0)")
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        help="Server & port address (default: localhost:8080, or SERVER_ADDRESS env var)",
    )
    parser.add_argument("--policy-type", type=str, default="smolvla", help="Policy type (default: smolvla)")
    parser.add_argument(
        "--pretrained-name-or-path",
        type=str,
        default="lerobot/smolvla_base",
        help="Pretrained model name or path (default: lerobot/smolvla_base)",
    )
    parser.add_argument(
        "--policy-device", type=str, default="cuda", help="Device for policy inference (default: cuda)"
    )
    parser.add_argument(
        "--chunk-size-threshold",
        type=float,
        default=0.5,
        help="Chunk size threshold (`g` in the paper, default: 0.5)",
    )
    parser.add_argument(
        "--robot",
        type=str,
        default="so100",
        help="Robot name, as per the `make_robot` function (default: so100)",
    )

    args = parser.parse_args()

    # Create client with parsed arguments
    client = RobotClient(
        server_address=args.server_address,
        policy_type=args.policy_type,
        pretrained_name_or_path=args.pretrained_name_or_path,
        policy_device=args.policy_device,
        chunk_size_threshold=args.chunk_size_threshold,
        robot=args.robot,
    )

    if client.start():
        # Function to get observations from the robot
        def get_observation():
            observation_content = client.robot.capture_observation()
            observation_content["task"] = [args.task]

            observation = TimedObservation(
                timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
            )

            return observation

        client.logger.info("Starting all threads...")

        # Create and start action receiver thread
        action_receiver_thread = threading.Thread(target=client.receive_actions)
        action_receiver_thread.daemon = True

        control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
        control_loop_thread.daemon = True

        # Start all threads
        action_receiver_thread.start()
        control_loop_thread.start()

        try:
            while client.running:
                time.sleep(idle_wait)

        except KeyboardInterrupt:
            pass

        finally:
            client.stop()
            client.logger.info("Client stopped")
@@ -63,7 +63,7 @@ dependencies = [
     "opencv-python-headless>=4.9.0",
     "packaging>=24.2",
     "av>=14.2.0",
-    "pymunk>=6.6.0",
+    "pymunk>=6.6.0,<7.0.0",
     "pynput>=1.7.7",
     "pyzmq>=26.2.1",
     "rerun-sdk>=0.21.0",
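The new upper bound keeps installs on the pymunk 6.x series, presumably to guard
against breaking changes in a future 7.0 release; the equivalent ad-hoc install
constraint would be `pip install "pymunk>=6.6.0,<7.0.0"`.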