Initial commit
This commit is contained in:
25
packages/openpi-client/pyproject.toml
Normal file
25
packages/openpi-client/pyproject.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[project]
|
||||
name = "openpi-client"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.7"
|
||||
dependencies = [
|
||||
"dm-tree>=0.1.8",
|
||||
"msgpack>=1.0.5",
|
||||
"numpy>=1.21.6",
|
||||
"pillow>=9.0.0",
|
||||
"tree>=0.2.4",
|
||||
"websockets>=11.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"pytest>=8.3.4",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py37"
|
||||
1
packages/openpi-client/src/openpi_client/__init__.py
Normal file
1
packages/openpi-client/src/openpi_client/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import tree
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
|
||||
|
||||
class ActionChunkBroker(_base_policy.BasePolicy):
|
||||
"""Wraps a policy to return action chunks one-at-a-time.
|
||||
|
||||
Assumes that the first dimension of all action fields is the chunk size.
|
||||
|
||||
A new inference call to the inner policy is only made when the current
|
||||
list of chunks is exhausted.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
|
||||
self._policy = policy
|
||||
|
||||
self._action_horizon = action_horizon
|
||||
self._cur_step: int = 0
|
||||
|
||||
self._last_results: Dict[str, np.ndarray] | None = None
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
if self._last_results is None:
|
||||
self._last_results = self._policy.infer(obs)
|
||||
self._cur_step = 0
|
||||
|
||||
results = tree.map_structure(lambda x: x[self._cur_step, ...], self._last_results)
|
||||
self._cur_step += 1
|
||||
|
||||
if self._cur_step >= self._action_horizon:
|
||||
self._last_results = None
|
||||
|
||||
return results
|
||||
8
packages/openpi-client/src/openpi_client/base_policy.py
Normal file
8
packages/openpi-client/src/openpi_client/base_policy.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import abc
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class BasePolicy(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def infer(self, obs: Dict) -> Dict:
|
||||
"""Infer actions from observations."""
|
||||
48
packages/openpi-client/src/openpi_client/image_tools.py
Normal file
48
packages/openpi-client/src/openpi_client/image_tools.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
|
||||
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
|
||||
|
||||
Args:
|
||||
images: A batch of images in [..., height, width, channel] format.
|
||||
height: The target height of the image.
|
||||
width: The target width of the image.
|
||||
method: The interpolation method to use. Default is bilinear.
|
||||
|
||||
Returns:
|
||||
The resized images in [..., height, width, channel].
|
||||
"""
|
||||
# If the images are already the correct size, return them as is.
|
||||
if images.shape[-3:-1] == (height, width):
|
||||
return images
|
||||
|
||||
original_shape = images.shape
|
||||
|
||||
images = images.reshape(-1, *original_shape[-3:])
|
||||
resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
|
||||
return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
|
||||
|
||||
|
||||
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
|
||||
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
|
||||
width without distortion by padding with zeros.
|
||||
|
||||
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
|
||||
"""
|
||||
cur_width, cur_height = image.size
|
||||
if cur_width == width and cur_height == height:
|
||||
return image # No need to resize if the image is already the correct size.
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_image = image.resize((resized_width, resized_height), resample=method)
|
||||
|
||||
zero_image = Image.new(resized_image.mode, (width, height), 0)
|
||||
pad_height = max(0, int((height - resized_height) / 2))
|
||||
pad_width = max(0, int((width - resized_width) / 2))
|
||||
zero_image.paste(resized_image, (pad_width, pad_height))
|
||||
assert zero_image.size == (width, height)
|
||||
return zero_image
|
||||
37
packages/openpi-client/src/openpi_client/image_tools_test.py
Normal file
37
packages/openpi-client/src/openpi_client/image_tools_test.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
|
||||
import openpi_client.image_tools as image_tools
|
||||
|
||||
|
||||
def test_resize_with_pad_shapes():
|
||||
# Test case 1: Resize image with larger dimensions
|
||||
images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels)
|
||||
height = 20
|
||||
width = 20
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (2, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 2: Resize image with smaller dimensions
|
||||
images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
|
||||
height = 15
|
||||
width = 15
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (3, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 3: Resize image with the same dimensions
|
||||
images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
|
||||
height = 50
|
||||
width = 50
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 3: Resize image with odd-numbered padding
|
||||
images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
|
||||
height = 60
|
||||
width = 80
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
57
packages/openpi-client/src/openpi_client/msgpack_numpy.py
Normal file
57
packages/openpi-client/src/openpi_client/msgpack_numpy.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Adds NumPy array support to msgpack.
|
||||
|
||||
msgpack is good for (de)serializing data over a network for multiple reasons:
|
||||
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
|
||||
- msgpack is widely used and has good cross-language support
|
||||
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
|
||||
languages like Python and JavaScript
|
||||
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
|
||||
than pickle for serializing large arrays using the below strategy
|
||||
|
||||
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
|
||||
that it falls back to pickle for object arrays.
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pack_array(obj):
|
||||
if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
|
||||
raise ValueError(f"Unsupported dtype: {obj.dtype}")
|
||||
|
||||
if isinstance(obj, np.ndarray):
|
||||
return {
|
||||
b"__ndarray__": True,
|
||||
b"data": obj.tobytes(),
|
||||
b"dtype": obj.dtype.str,
|
||||
b"shape": obj.shape,
|
||||
}
|
||||
|
||||
if isinstance(obj, np.generic):
|
||||
return {
|
||||
b"__npgeneric__": True,
|
||||
b"data": obj.item(),
|
||||
b"dtype": obj.dtype.str,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def unpack_array(obj):
|
||||
if b"__ndarray__" in obj:
|
||||
return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
|
||||
|
||||
if b"__npgeneric__" in obj:
|
||||
return np.dtype(obj[b"dtype"]).type(obj[b"data"])
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
Packer = functools.partial(msgpack.Packer, default=pack_array)
|
||||
packb = functools.partial(msgpack.packb, default=pack_array)
|
||||
|
||||
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
|
||||
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
|
||||
@@ -0,0 +1,45 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tree
|
||||
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
def _check(expected, actual):
|
||||
if isinstance(expected, np.ndarray):
|
||||
assert expected.shape == actual.shape
|
||||
assert expected.dtype == actual.dtype
|
||||
assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
|
||||
else:
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data",
|
||||
[
|
||||
1, # int
|
||||
1.0, # float
|
||||
"hello", # string
|
||||
np.bool_(True), # boolean scalar
|
||||
np.array([1, 2, 3])[0], # int scalar
|
||||
np.str_("asdf"), # string scalar
|
||||
[1, 2, 3], # list
|
||||
{"key": "value"}, # dict
|
||||
{"key": [1, 2, 3]}, # nested dict
|
||||
np.array(1.0), # 0D array
|
||||
np.array([1, 2, 3], dtype=np.int32), # 1D integer array
|
||||
np.array(["asdf", "qwer"]), # string array
|
||||
np.array([True, False]), # boolean array
|
||||
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
|
||||
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array
|
||||
np.array([np.nan, np.inf, -np.inf]), # special float values
|
||||
{"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays
|
||||
[np.array([1, 2]), np.array([3, 4])], # list of arrays
|
||||
np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
|
||||
np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
|
||||
],
|
||||
)
|
||||
def test_pack_unpack(data):
|
||||
packed = msgpack_numpy.packb(data)
|
||||
unpacked = msgpack_numpy.unpackb(packed)
|
||||
tree.map_structure(_check, data, unpacked)
|
||||
13
packages/openpi-client/src/openpi_client/runtime/agent.py
Normal file
13
packages/openpi-client/src/openpi_client/runtime/agent.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Agent(abc.ABC):
|
||||
"""An Agent is the thing with agency, i.e. the entity that makes decisions.
|
||||
|
||||
Agents receive observations about the state of the world, and return actions
|
||||
to take in response.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
"""Query the agent for the next action."""
|
||||
@@ -0,0 +1,15 @@
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client.runtime import agent as _agent
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# TODO: Consider unifying policies and agents.
|
||||
class PolicyAgent(_agent.Agent):
|
||||
"""An agent that uses a policy to determine actions."""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy) -> None:
|
||||
self._policy = policy
|
||||
|
||||
@override
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
return self._policy.infer(observation)
|
||||
@@ -0,0 +1,32 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Environment(abc.ABC):
|
||||
"""An Environment represents the robot and the environment it inhabits.
|
||||
|
||||
The primary contract of environments is that they can be queried for observations
|
||||
about their state, and have actions applied to them to change that state.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the environment to its initial state.
|
||||
|
||||
This will be called once before starting each episode.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def done(self) -> bool:
|
||||
"""Allow the environment to signal that the task is done.
|
||||
|
||||
This will be called after each step. It should return `True` if the task is
|
||||
done (either successfully or unsuccessfully), and `False` otherwise.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_observation(self) -> dict:
|
||||
"""Query the environment for the current state."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def apply_action(self, action: dict) -> None:
|
||||
"""Take an action in the environment."""
|
||||
78
packages/openpi-client/src/openpi_client/runtime/runtime.py
Normal file
78
packages/openpi-client/src/openpi_client/runtime/runtime.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from openpi_client.runtime import agent as _agent
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
|
||||
|
||||
class Runtime:
|
||||
"""The core module orchestrating interactions between key components of the system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
environment: _environment.Environment,
|
||||
agent: _agent.Agent,
|
||||
subscribers: list[_subscriber.Subscriber],
|
||||
max_hz: float = 0,
|
||||
) -> None:
|
||||
self._environment = environment
|
||||
self._agent = agent
|
||||
self._subscribers = subscribers
|
||||
self._max_hz = max_hz
|
||||
|
||||
self._running = False
|
||||
|
||||
def run(self) -> None:
|
||||
"""Runs the runtime loop continuously until stop() is called or the environment is done."""
|
||||
self._loop()
|
||||
|
||||
def run_in_new_thread(self) -> threading.Thread:
|
||||
"""Runs the runtime loop in a new thread."""
|
||||
thread = threading.Thread(target=self.run)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops the runtime loop."""
|
||||
self._running = False
|
||||
|
||||
def _loop(self) -> None:
|
||||
"""The runtime loop."""
|
||||
logging.info("Starting episode...")
|
||||
self._environment.reset()
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_start()
|
||||
|
||||
self._running = True
|
||||
step_time = 1 / self._max_hz if self._max_hz > 0 else 0
|
||||
last_step_time = time.time()
|
||||
|
||||
while self._running:
|
||||
self._step()
|
||||
|
||||
# Sleep to maintain the desired frame rate
|
||||
now = time.time()
|
||||
dt = now - last_step_time
|
||||
if dt < step_time:
|
||||
time.sleep(step_time - dt)
|
||||
last_step_time = time.time()
|
||||
else:
|
||||
last_step_time = now
|
||||
|
||||
logging.info("Episode completed.")
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_end()
|
||||
|
||||
def _step(self) -> None:
|
||||
"""A single step of the runtime loop."""
|
||||
observation = self._environment.get_observation()
|
||||
action = self._agent.get_action(observation)
|
||||
self._environment.apply_action(action)
|
||||
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_step(observation, action)
|
||||
|
||||
if self._environment.done():
|
||||
self.stop()
|
||||
@@ -0,0 +1,20 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Subscriber(abc.ABC):
|
||||
"""Subscribes to events in the runtime.
|
||||
|
||||
Subscribers can be used to save data, visualize, etc.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_start(self) -> None:
|
||||
"""Called when an episode starts."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
"""Append a step to the episode."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_end(self) -> None:
|
||||
"""Called when an episode ends."""
|
||||
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from typing_extensions import override
|
||||
import websockets.sync.client
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
class WebsocketClientPolicy(_base_policy.BasePolicy):
|
||||
"""Implements the Policy interface by communicating with a server over websocket.
|
||||
|
||||
See WebsocketPolicyServer for a corresponding server implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000) -> None:
|
||||
self._uri = f"ws://{host}:{port}"
|
||||
self._packer = msgpack_numpy.Packer()
|
||||
self._ws = self._wait_for_server()
|
||||
|
||||
def _wait_for_server(self) -> websockets.sync.client.ClientConnection:
|
||||
logging.info(f"Waiting for server at {self._uri}...")
|
||||
while True:
|
||||
try:
|
||||
return websockets.sync.client.connect(self._uri, compression=None, max_size=None)
|
||||
except ConnectionRefusedError:
|
||||
logging.info("Still waiting for server...")
|
||||
time.sleep(5)
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
data = self._packer.pack(obs)
|
||||
self._ws.send(data)
|
||||
response = self._ws.recv()
|
||||
if isinstance(response, str):
|
||||
# we're expecting bytes; if the server sends a string, it's an error.
|
||||
raise RuntimeError(f"Error in inference server:\n{response}")
|
||||
return msgpack_numpy.unpackb(response)
|
||||
Reference in New Issue
Block a user