multi-node openpi commit
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
# Dockerfile for the simple client.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t simple_client -f examples/simple_client/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
|
||||
|
||||
FROM python:3.7-slim
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
|
||||
|
||||
CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
|
||||
30
policy/openpi-InternData-A1/examples/simple_client/README.md
Normal file
30
policy/openpi-InternData-A1/examples/simple_client/README.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# Simple Client
|
||||
|
||||
A minimal client that sends observations to the server and prints the inference rate.
|
||||
|
||||
You can specify which runtime environment to use using the `--env` flag. You can see the available options by running:
|
||||
|
||||
```bash
|
||||
uv run examples/simple_client/main.py --help
|
||||
```
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA_SIM"
|
||||
docker compose -f examples/simple_client/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
uv run examples/simple_client/main.py --env DROID
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env DROID
|
||||
```
|
||||
@@ -0,0 +1,42 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/simple_client/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: simple_client
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/simple_client/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
187
policy/openpi-InternData-A1/examples/simple_client/main.py
Normal file
187
policy/openpi-InternData-A1/examples/simple_client/main.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import pathlib
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import polars as pl
|
||||
import rich
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
|
||||
"""Supported environments."""
|
||||
|
||||
ALOHA = "aloha"
|
||||
ALOHA_SIM = "aloha_sim"
|
||||
DROID = "droid"
|
||||
LIBERO = "libero"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
"""Command line arguments."""
|
||||
|
||||
# Host and port to connect to the server.
|
||||
host: str = "0.0.0.0"
|
||||
# Port to connect to the server. If None, the server will use the default port.
|
||||
port: int | None = 8000
|
||||
# API key to use for the server.
|
||||
api_key: str | None = None
|
||||
# Number of steps to run the policy for.
|
||||
num_steps: int = 20
|
||||
# Path to save the timings to a parquet file. (e.g., timing.parquet)
|
||||
timing_file: pathlib.Path | None = None
|
||||
# Environment to run the policy in.
|
||||
env: EnvMode = EnvMode.ALOHA_SIM
|
||||
|
||||
|
||||
class TimingRecorder:
|
||||
"""Records timing measurements for different keys."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._timings: dict[str, list[float]] = {}
|
||||
|
||||
def record(self, key: str, time_ms: float) -> None:
|
||||
"""Record a timing measurement for the given key."""
|
||||
if key not in self._timings:
|
||||
self._timings[key] = []
|
||||
self._timings[key].append(time_ms)
|
||||
|
||||
def get_stats(self, key: str) -> dict[str, float]:
|
||||
"""Get statistics for the given key."""
|
||||
times = self._timings[key]
|
||||
return {
|
||||
"mean": float(np.mean(times)),
|
||||
"std": float(np.std(times)),
|
||||
"p25": float(np.quantile(times, 0.25)),
|
||||
"p50": float(np.quantile(times, 0.50)),
|
||||
"p75": float(np.quantile(times, 0.75)),
|
||||
"p90": float(np.quantile(times, 0.90)),
|
||||
"p95": float(np.quantile(times, 0.95)),
|
||||
"p99": float(np.quantile(times, 0.99)),
|
||||
}
|
||||
|
||||
def print_all_stats(self) -> None:
|
||||
"""Print statistics for all keys in a concise format."""
|
||||
|
||||
table = rich.table.Table(
|
||||
title="[bold blue]Timing Statistics[/bold blue]",
|
||||
show_header=True,
|
||||
header_style="bold white",
|
||||
border_style="blue",
|
||||
title_justify="center",
|
||||
)
|
||||
|
||||
# Add metric column with custom styling
|
||||
table.add_column("Metric", style="cyan", justify="left", no_wrap=True)
|
||||
|
||||
# Add statistical columns with consistent styling
|
||||
stat_columns = [
|
||||
("Mean", "yellow", "mean"),
|
||||
("Std", "yellow", "std"),
|
||||
("P25", "magenta", "p25"),
|
||||
("P50", "magenta", "p50"),
|
||||
("P75", "magenta", "p75"),
|
||||
("P90", "magenta", "p90"),
|
||||
("P95", "magenta", "p95"),
|
||||
("P99", "magenta", "p99"),
|
||||
]
|
||||
|
||||
for name, style, _ in stat_columns:
|
||||
table.add_column(name, justify="right", style=style, no_wrap=True)
|
||||
|
||||
# Add rows for each metric with formatted values
|
||||
for key in sorted(self._timings.keys()):
|
||||
stats = self.get_stats(key)
|
||||
values = [f"{stats[key]:.1f}" for _, _, key in stat_columns]
|
||||
table.add_row(key, *values)
|
||||
|
||||
# Print with custom console settings
|
||||
console = rich.console.Console(width=None, highlight=True)
|
||||
console.print(table)
|
||||
|
||||
def write_parquet(self, path: pathlib.Path) -> None:
|
||||
"""Save the timings to a parquet file."""
|
||||
logger.info(f"Writing timings to {path}")
|
||||
frame = pl.DataFrame(self._timings)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
frame.write_parquet(path)
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
obs_fn = {
|
||||
EnvMode.ALOHA: _random_observation_aloha,
|
||||
EnvMode.ALOHA_SIM: _random_observation_aloha,
|
||||
EnvMode.DROID: _random_observation_droid,
|
||||
EnvMode.LIBERO: _random_observation_libero,
|
||||
}[args.env]
|
||||
|
||||
policy = _websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
api_key=args.api_key,
|
||||
)
|
||||
logger.info(f"Server metadata: {policy.get_server_metadata()}")
|
||||
|
||||
# Send a few observations to make sure the model is loaded.
|
||||
for _ in range(2):
|
||||
policy.infer(obs_fn())
|
||||
|
||||
timing_recorder = TimingRecorder()
|
||||
|
||||
for _ in tqdm.trange(args.num_steps, desc="Running policy"):
|
||||
inference_start = time.time()
|
||||
action = policy.infer(obs_fn())
|
||||
timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start))
|
||||
for key, value in action.get("server_timing", {}).items():
|
||||
timing_recorder.record(f"server_{key}", value)
|
||||
for key, value in action.get("policy_timing", {}).items():
|
||||
timing_recorder.record(f"policy_{key}", value)
|
||||
|
||||
timing_recorder.print_all_stats()
|
||||
|
||||
if args.timing_file is not None:
|
||||
timing_recorder.write_parquet(args.timing_file)
|
||||
|
||||
|
||||
def _random_observation_aloha() -> dict:
|
||||
return {
|
||||
"state": np.ones((14,)),
|
||||
"images": {
|
||||
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
},
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_droid() -> dict:
|
||||
return {
|
||||
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/joint_position": np.random.rand(7),
|
||||
"observation/gripper_position": np.random.rand(1),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_libero() -> dict:
|
||||
return {
|
||||
"observation/state": np.random.rand(8),
|
||||
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main(tyro.cli(Args))
|
||||
@@ -0,0 +1,5 @@
|
||||
numpy>=1.22.4,<2.0.0
|
||||
rich
|
||||
tqdm
|
||||
tyro
|
||||
polars
|
||||
@@ -0,0 +1,30 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
numpy==1.26.4
|
||||
# via -r examples/simple_client/requirements.in
|
||||
polars==1.30.0
|
||||
# via -r examples/simple_client/requirements.in
|
||||
pygments==2.19.1
|
||||
# via rich
|
||||
rich==14.0.0
|
||||
# via
|
||||
# -r examples/simple_client/requirements.in
|
||||
# tyro
|
||||
shtab==1.7.2
|
||||
# via tyro
|
||||
tqdm==4.67.1
|
||||
# via -r examples/simple_client/requirements.in
|
||||
typeguard==4.4.2
|
||||
# via tyro
|
||||
typing-extensions==4.13.2
|
||||
# via
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.22
|
||||
# via -r examples/simple_client/requirements.in
|
||||
Reference in New Issue
Block a user