multi-node openpi commit

This commit is contained in:
Leon998
2026-03-17 23:05:23 +08:00
parent 28833f0c0f
commit 7411e0e004
156 changed files with 33951 additions and 1 deletions

View File

@@ -0,0 +1,32 @@
# Dockerfile for the simple client.
# Build the container:
# docker build . -t simple_client -f examples/simple_client/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"

View File

@@ -0,0 +1,30 @@
# Simple Client
A minimal client that sends observations to the server and prints the inference rate.
You can specify which runtime environment to use using the `--env` flag. You can see the available options by running:
```bash
uv run examples/simple_client/main.py --help
```
## With Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/simple_client/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
uv run examples/simple_client/main.py --env DROID
```
Terminal window 2:
```bash
uv run scripts/serve_policy.py --env DROID
```

View File

@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/simple_client/compose.yml up --build
services:
runtime:
image: simple_client
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/simple_client/Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
environment:
- SERVER_ARGS
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -0,0 +1,187 @@
import dataclasses
import enum
import logging
import pathlib
import time
import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import polars as pl
import rich
import tqdm
import tyro
logger = logging.getLogger(__name__)
class EnvMode(enum.Enum):
"""Supported environments."""
ALOHA = "aloha"
ALOHA_SIM = "aloha_sim"
DROID = "droid"
LIBERO = "libero"
@dataclasses.dataclass
class Args:
"""Command line arguments."""
# Host and port to connect to the server.
host: str = "0.0.0.0"
# Port to connect to the server. If None, the server will use the default port.
port: int | None = 8000
# API key to use for the server.
api_key: str | None = None
# Number of steps to run the policy for.
num_steps: int = 20
# Path to save the timings to a parquet file. (e.g., timing.parquet)
timing_file: pathlib.Path | None = None
# Environment to run the policy in.
env: EnvMode = EnvMode.ALOHA_SIM
class TimingRecorder:
"""Records timing measurements for different keys."""
def __init__(self) -> None:
self._timings: dict[str, list[float]] = {}
def record(self, key: str, time_ms: float) -> None:
"""Record a timing measurement for the given key."""
if key not in self._timings:
self._timings[key] = []
self._timings[key].append(time_ms)
def get_stats(self, key: str) -> dict[str, float]:
"""Get statistics for the given key."""
times = self._timings[key]
return {
"mean": float(np.mean(times)),
"std": float(np.std(times)),
"p25": float(np.quantile(times, 0.25)),
"p50": float(np.quantile(times, 0.50)),
"p75": float(np.quantile(times, 0.75)),
"p90": float(np.quantile(times, 0.90)),
"p95": float(np.quantile(times, 0.95)),
"p99": float(np.quantile(times, 0.99)),
}
def print_all_stats(self) -> None:
"""Print statistics for all keys in a concise format."""
table = rich.table.Table(
title="[bold blue]Timing Statistics[/bold blue]",
show_header=True,
header_style="bold white",
border_style="blue",
title_justify="center",
)
# Add metric column with custom styling
table.add_column("Metric", style="cyan", justify="left", no_wrap=True)
# Add statistical columns with consistent styling
stat_columns = [
("Mean", "yellow", "mean"),
("Std", "yellow", "std"),
("P25", "magenta", "p25"),
("P50", "magenta", "p50"),
("P75", "magenta", "p75"),
("P90", "magenta", "p90"),
("P95", "magenta", "p95"),
("P99", "magenta", "p99"),
]
for name, style, _ in stat_columns:
table.add_column(name, justify="right", style=style, no_wrap=True)
# Add rows for each metric with formatted values
for key in sorted(self._timings.keys()):
stats = self.get_stats(key)
values = [f"{stats[key]:.1f}" for _, _, key in stat_columns]
table.add_row(key, *values)
# Print with custom console settings
console = rich.console.Console(width=None, highlight=True)
console.print(table)
def write_parquet(self, path: pathlib.Path) -> None:
"""Save the timings to a parquet file."""
logger.info(f"Writing timings to {path}")
frame = pl.DataFrame(self._timings)
path.parent.mkdir(parents=True, exist_ok=True)
frame.write_parquet(path)
def main(args: Args) -> None:
obs_fn = {
EnvMode.ALOHA: _random_observation_aloha,
EnvMode.ALOHA_SIM: _random_observation_aloha,
EnvMode.DROID: _random_observation_droid,
EnvMode.LIBERO: _random_observation_libero,
}[args.env]
policy = _websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
api_key=args.api_key,
)
logger.info(f"Server metadata: {policy.get_server_metadata()}")
# Send a few observations to make sure the model is loaded.
for _ in range(2):
policy.infer(obs_fn())
timing_recorder = TimingRecorder()
for _ in tqdm.trange(args.num_steps, desc="Running policy"):
inference_start = time.time()
action = policy.infer(obs_fn())
timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start))
for key, value in action.get("server_timing", {}).items():
timing_recorder.record(f"server_{key}", value)
for key, value in action.get("policy_timing", {}).items():
timing_recorder.record(f"policy_{key}", value)
timing_recorder.print_all_stats()
if args.timing_file is not None:
timing_recorder.write_parquet(args.timing_file)
def _random_observation_aloha() -> dict:
return {
"state": np.ones((14,)),
"images": {
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
},
"prompt": "do something",
}
def _random_observation_droid() -> dict:
return {
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/joint_position": np.random.rand(7),
"observation/gripper_position": np.random.rand(1),
"prompt": "do something",
}
def _random_observation_libero() -> dict:
return {
"observation/state": np.random.rand(8),
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"prompt": "do something",
}
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main(tyro.cli(Args))

View File

@@ -0,0 +1,5 @@
numpy>=1.22.4,<2.0.0
rich
tqdm
tyro
polars

View File

@@ -0,0 +1,30 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
docstring-parser==0.16
# via tyro
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
numpy==1.26.4
# via -r examples/simple_client/requirements.in
polars==1.30.0
# via -r examples/simple_client/requirements.in
pygments==2.19.1
# via rich
rich==14.0.0
# via
# -r examples/simple_client/requirements.in
# tyro
shtab==1.7.2
# via tyro
tqdm==4.67.1
# via -r examples/simple_client/requirements.in
typeguard==4.4.2
# via tyro
typing-extensions==4.13.2
# via
# typeguard
# tyro
tyro==0.9.22
# via -r examples/simple_client/requirements.in