188 lines
6.1 KiB
Python
188 lines
6.1 KiB
Python
import dataclasses
|
|
import enum
|
|
import logging
|
|
import pathlib
|
|
import time
|
|
|
|
import numpy as np
|
|
from openpi_client import websocket_client_policy as _websocket_client_policy
|
|
import polars as pl
|
|
import rich
|
|
import tqdm
|
|
import tyro
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EnvMode(enum.Enum):
    """Environment whose observation layout the benchmark should emulate.

    ``main`` maps each member to the generator used to build random
    observations; ``ALOHA`` and ``ALOHA_SIM`` share the same generator.
    """

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
|
|
|
|
|
|
# NOTE: tyro builds the --help text from the class docstring and from the
# comment directly above each field, so those comments are user-facing. Keep
# maintainer notes up here rather than next to the fields.
@dataclasses.dataclass
class Args:
    """Command line arguments."""

    # Host to connect to the server.
    host: str = "0.0.0.0"
    # Port to connect to the server. If None, the server will use the default port.
    port: int | None = 8000
    # API key to use for the server.
    api_key: str | None = None
    # Number of steps to run the policy for.
    num_steps: int = 20
    # Path to save the timings to a parquet file. (e.g., timing.parquet)
    timing_file: pathlib.Path | None = None
    # Environment to run the policy in.
    env: EnvMode = EnvMode.ALOHA_SIM
|
|
|
|
|
|
class TimingRecorder:
    """Accumulate timing samples per key and report summary statistics.

    Samples are stored in a dict mapping each key to the list of values
    recorded for it, in the order they were recorded.
    """

    def __init__(self) -> None:
        self._timings: dict[str, list[float]] = {}

    def record(self, key: str, time_ms: float) -> None:
        """Append one sample for ``key``, creating its series on first use."""
        self._timings.setdefault(key, []).append(time_ms)

    def get_stats(self, key: str) -> dict[str, float]:
        """Return summary statistics of the samples recorded for ``key``.

        Returns:
            A dict with keys ``mean``, ``std``, ``p25``, ``p50``, ``p75``,
            ``p90``, ``p95`` and ``p99`` (in that order), all plain floats.
            ``std`` is computed with ``np.std`` using its default arguments.

        Raises:
            KeyError: If nothing has been recorded for ``key``.
        """
        times = self._timings[key]
        quantile_levels = {
            "p25": 0.25,
            "p50": 0.50,
            "p75": 0.75,
            "p90": 0.90,
            "p95": 0.95,
            "p99": 0.99,
        }
        stats = {
            "mean": float(np.mean(times)),
            "std": float(np.std(times)),
        }
        # One vectorised call evaluates every quantile over the same samples.
        quantile_values = np.quantile(times, list(quantile_levels.values()))
        for label, value in zip(quantile_levels, quantile_values):
            stats[label] = float(value)
        return stats

    def print_all_stats(self) -> None:
        """Print one table row of statistics per recorded key, sorted by key."""
        # A bare `import rich` does not guarantee that the `table` and
        # `console` submodules are loaded, so import them explicitly here
        # instead of relying on another package having done so.
        import rich.console
        import rich.table

        table = rich.table.Table(
            title="[bold blue]Timing Statistics[/bold blue]",
            show_header=True,
            header_style="bold white",
            border_style="blue",
            title_justify="center",
        )

        # Add metric column with custom styling
        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)

        # (column header, column style, key into the get_stats() result)
        stat_columns = [
            ("Mean", "yellow", "mean"),
            ("Std", "yellow", "std"),
            ("P25", "magenta", "p25"),
            ("P50", "magenta", "p50"),
            ("P75", "magenta", "p75"),
            ("P90", "magenta", "p90"),
            ("P95", "magenta", "p95"),
            ("P99", "magenta", "p99"),
        ]

        for name, style, _ in stat_columns:
            table.add_column(name, justify="right", style=style, no_wrap=True)

        # Add rows for each metric with formatted values
        for key in sorted(self._timings):
            stats = self.get_stats(key)
            # `stat_key` is deliberately distinct from the outer `key`.
            values = [f"{stats[stat_key]:.1f}" for _, _, stat_key in stat_columns]
            table.add_row(key, *values)

        # Print with custom console settings
        console = rich.console.Console(width=None, highlight=True)
        console.print(table)

    def write_parquet(self, path: pathlib.Path) -> None:
        """Save the timings to a parquet file, one column per key.

        A DataFrame needs columns of equal length, but a key that was not
        recorded on every step has fewer samples than the others. Shorter
        columns are therefore padded with trailing nulls; as a consequence a
        row only corresponds to a single step for keys recorded on every step.
        Missing parent directories of ``path`` are created.
        """
        logger.info("Writing timings to %s", path)
        longest = max((len(samples) for samples in self._timings.values()), default=0)
        columns = {
            key: samples + [None] * (longest - len(samples))
            for key, samples in self._timings.items()
        }
        frame = pl.DataFrame(columns)
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.write_parquet(path)
|
|
|
|
|
|
def main(args: Args) -> None:
    """Benchmark inference latency against a policy server.

    Connects to the server given by ``args.host`` / ``args.port``, sends two
    untimed warm-up requests, then issues ``args.num_steps`` timed requests
    using random observations for ``args.env``. For each request the client
    round-trip time is recorded as ``client_infer_ms``, together with any
    entries found under ``server_timing`` and ``policy_timing`` in the
    response. The statistics are printed and, when ``args.timing_file`` is
    set, the raw samples are written to that parquet file.
    """
    obs_fn = {
        EnvMode.ALOHA: _random_observation_aloha,
        EnvMode.ALOHA_SIM: _random_observation_aloha,
        EnvMode.DROID: _random_observation_droid,
        EnvMode.LIBERO: _random_observation_libero,
    }[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
        api_key=args.api_key,
    )
    logger.info("Server metadata: %s", policy.get_server_metadata())

    # Send a few observations to make sure the model is loaded.
    for _ in range(2):
        policy.infer(obs_fn())

    timing_recorder = TimingRecorder()

    for _ in tqdm.trange(args.num_steps, desc="Running policy"):
        # Build the observation before starting the clock so that generating
        # random images does not inflate the measured inference latency.
        observation = obs_fn()
        # perf_counter() is monotonic and high resolution; time.time() can
        # jump when the system clock is adjusted.
        inference_start = time.perf_counter()
        action = policy.infer(observation)
        timing_recorder.record("client_infer_ms", 1000 * (time.perf_counter() - inference_start))
        # Timing entries reported in the response are recorded as-is under a
        # prefix naming the response field they came from.
        for key, value in action.get("server_timing", {}).items():
            timing_recorder.record(f"server_{key}", value)
        for key, value in action.get("policy_timing", {}).items():
            timing_recorder.record(f"policy_{key}", value)

    timing_recorder.print_all_stats()

    if args.timing_file is not None:
        timing_recorder.write_parquet(args.timing_file)
|
|
|
|
|
|
def _random_observation_aloha() -> dict:
|
|
return {
|
|
"state": np.ones((14,)),
|
|
"images": {
|
|
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
|
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
|
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
|
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
|
},
|
|
"prompt": "do something",
|
|
}
|
|
|
|
|
|
def _random_observation_droid() -> dict:
|
|
return {
|
|
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
|
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
|
"observation/joint_position": np.random.rand(7),
|
|
"observation/gripper_position": np.random.rand(1),
|
|
"prompt": "do something",
|
|
}
|
|
|
|
|
|
def _random_observation_libero() -> dict:
|
|
return {
|
|
"observation/state": np.random.rand(8),
|
|
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
|
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
|
"prompt": "do something",
|
|
}
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: enable INFO logging, parse the command line into an
    # Args instance, then run the benchmark.
    logging.basicConfig(level=logging.INFO)
    cli_args = tyro.cli(Args)
    main(cli_args)
|