Initial commit

Ury Zhilinsky
2025-02-03 21:43:26 -08:00
commit 231a1cf7ca
121 changed files with 16349 additions and 0 deletions


@@ -0,0 +1 @@
__version__ = "0.1.0"


@@ -0,0 +1,45 @@
from typing import Dict

import numpy as np
import tree
from typing_extensions import override

from openpi_client import base_policy as _base_policy


class ActionChunkBroker(_base_policy.BasePolicy):
    """Wraps a policy to return action chunks one-at-a-time.

    Assumes that the first dimension of all action fields is the chunk size.

    A new inference call to the inner policy is only made when the current
    list of chunks is exhausted.
    """

    def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
        self._policy = policy
        self._action_horizon = action_horizon

        self._cur_step: int = 0
        self._last_results: Dict[str, np.ndarray] | None = None

    @override
    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
        if self._last_results is None:
            self._last_results = self._policy.infer(obs)
            self._cur_step = 0

        results = tree.map_structure(lambda x: x[self._cur_step, ...], self._last_results)

        self._cur_step += 1
        if self._cur_step >= self._action_horizon:
            self._last_results = None

        return results

    @override
    def reset(self) -> None:
        self._policy.reset()
        self._last_results = None
        self._cur_step = 0
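
A minimal usage sketch of the broker. The module name openpi_client.action_chunk_broker is an assumption (the filename is not shown in this diff), and DummyPolicy is a made-up inner policy whose action chunks have a first dimension of 10:

import numpy as np

from openpi_client import base_policy as _base_policy
from openpi_client import action_chunk_broker  # assumed module name; not shown in this diff


class DummyPolicy(_base_policy.BasePolicy):
    """Made-up policy that returns a chunk of 10 actions per inference call."""

    def infer(self, obs: dict) -> dict:
        return {"actions": np.zeros((10, 7), dtype=np.float32)}  # (chunk, action_dim)


broker = action_chunk_broker.ActionChunkBroker(DummyPolicy(), action_horizon=10)
for _ in range(20):
    # Only two calls reach DummyPolicy.infer; the other 18 steps are served from the cached chunk.
    step_action = broker.infer({"state": np.zeros(7, dtype=np.float32)})
    assert step_action["actions"].shape == (7,)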


@@ -0,0 +1,12 @@
import abc
from typing import Dict


class BasePolicy(abc.ABC):
    @abc.abstractmethod
    def infer(self, obs: Dict) -> Dict:
        """Infer actions from observations."""

    def reset(self) -> None:
        """Reset the policy to its initial state."""
        pass


@@ -0,0 +1,58 @@
import numpy as np
from PIL import Image


def convert_to_uint8(img: np.ndarray) -> np.ndarray:
    """Converts an image to uint8 if it is a float image.

    This is important for reducing the size of the image when sending it over the network.
    """
    if np.issubdtype(img.dtype, np.floating):
        img = (255 * img).astype(np.uint8)
    return img


def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
    """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target
    height and width.

    Args:
        images: A batch of images in [..., height, width, channel] format.
        height: The target height of the image.
        width: The target width of the image.
        method: The interpolation method to use. Default is bilinear.

    Returns:
        The resized images in [..., height, width, channel].
    """
    # If the images are already the correct size, return them as is.
    if images.shape[-3:-1] == (height, width):
        return images

    original_shape = images.shape
    images = images.reshape(-1, *original_shape[-3:])
    resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
    return resized.reshape(*original_shape[:-3], *resized.shape[-3:])


def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
    """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
    width without distortion by padding with zeros.

    Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
    """
    cur_width, cur_height = image.size
    if cur_width == width and cur_height == height:
        return image  # No need to resize if the image is already the correct size.

    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_image = image.resize((resized_width, resized_height), resample=method)

    zero_image = Image.new(resized_image.mode, (width, height), 0)
    pad_height = max(0, int((height - resized_height) / 2))
    pad_width = max(0, int((width - resized_width) / 2))
    zero_image.paste(resized_image, (pad_width, pad_height))
    assert zero_image.size == (width, height)
    return zero_image
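
A short sketch of how these two helpers compose; the frame shape and batch size are arbitrary:

import numpy as np

from openpi_client import image_tools

# Two float camera frames in [0, 1], e.g. produced by a renderer.
frames = np.random.rand(2, 480, 640, 3).astype(np.float32)

# Convert to uint8 first (smaller payload, and a dtype PIL can load directly),
# then letterbox to 224x224 without distorting the aspect ratio.
frames_u8 = image_tools.convert_to_uint8(frames)
resized = image_tools.resize_with_pad(frames_u8, 224, 224)
assert resized.shape == (2, 224, 224, 3) and resized.dtype == np.uint8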


@@ -0,0 +1,37 @@
import numpy as np

import openpi_client.image_tools as image_tools


def test_resize_with_pad_shapes():
    # Test case 1: Resize image with larger dimensions
    images = np.zeros((2, 10, 10, 3), dtype=np.uint8)  # Input images of shape (batch_size, height, width, channels)
    height = 20
    width = 20
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (2, height, width, 3)
    assert np.all(resized_images == 0)

    # Test case 2: Resize image with smaller dimensions
    images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
    height = 15
    width = 15
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (3, height, width, 3)
    assert np.all(resized_images == 0)

    # Test case 3: Resize image with the same dimensions
    images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
    height = 50
    width = 50
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert np.all(resized_images == 0)

    # Test case 4: Resize image with odd-numbered padding
    images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
    height = 60
    width = 80
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert np.all(resized_images == 0)


@@ -0,0 +1,57 @@
"""Adds NumPy array support to msgpack.
msgpack is good for (de)serializing data over a network for multiple reasons:
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
- msgpack is widely used and has good cross-language support
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
languages like Python and JavaScript
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
than pickle for serializing large arrays using the below strategy
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
that it falls back to pickle for object arrays.
"""
import functools
import msgpack
import numpy as np
def pack_array(obj):
if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
raise ValueError(f"Unsupported dtype: {obj.dtype}")
if isinstance(obj, np.ndarray):
return {
b"__ndarray__": True,
b"data": obj.tobytes(),
b"dtype": obj.dtype.str,
b"shape": obj.shape,
}
if isinstance(obj, np.generic):
return {
b"__npgeneric__": True,
b"data": obj.item(),
b"dtype": obj.dtype.str,
}
return obj
def unpack_array(obj):
if b"__ndarray__" in obj:
return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
if b"__npgeneric__" in obj:
return np.dtype(obj[b"dtype"]).type(obj[b"data"])
return obj
Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
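
A round-trip sketch of the wire format; the observation keys and shapes are arbitrary:

import numpy as np

from openpi_client import msgpack_numpy

obs = {
    "image": np.zeros((224, 224, 3), dtype=np.uint8),
    "state": np.arange(7, dtype=np.float32),
}
payload = msgpack_numpy.packb(obs)         # plain bytes, safe to send over a socket
restored = msgpack_numpy.unpackb(payload)  # arrays come back with their original dtype and shape
assert restored["state"].dtype == np.float32
assert np.array_equal(restored["image"], obs["image"])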


@@ -0,0 +1,45 @@
import numpy as np
import pytest
import tree

from openpi_client import msgpack_numpy


def _check(expected, actual):
    if isinstance(expected, np.ndarray):
        assert expected.shape == actual.shape
        assert expected.dtype == actual.dtype
        assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
    else:
        assert expected == actual


@pytest.mark.parametrize(
    "data",
    [
        1,  # int
        1.0,  # float
        "hello",  # string
        np.bool_(True),  # boolean scalar
        np.array([1, 2, 3])[0],  # int scalar
        np.str_("asdf"),  # string scalar
        [1, 2, 3],  # list
        {"key": "value"},  # dict
        {"key": [1, 2, 3]},  # nested dict
        np.array(1.0),  # 0D array
        np.array([1, 2, 3], dtype=np.int32),  # 1D integer array
        np.array(["asdf", "qwer"]),  # string array
        np.array([True, False]),  # boolean array
        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),  # 2D float array
        np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16),  # 3D integer array
        np.array([np.nan, np.inf, -np.inf]),  # special float values
        {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}},  # nested dict with arrays
        [np.array([1, 2]), np.array([3, 4])],  # list of arrays
        np.zeros((3, 4, 5), dtype=np.float32),  # 3D zeros
        np.ones((2, 3), dtype=np.float64),  # 2D ones with double precision
    ],
)
def test_pack_unpack(data):
    packed = msgpack_numpy.packb(data)
    unpacked = msgpack_numpy.unpackb(packed)

    tree.map_structure(_check, data, unpacked)


@@ -0,0 +1,17 @@
import abc


class Agent(abc.ABC):
    """An Agent is the thing with agency, i.e. the entity that makes decisions.

    Agents receive observations about the state of the world, and return actions
    to take in response.
    """

    @abc.abstractmethod
    def get_action(self, observation: dict) -> dict:
        """Query the agent for the next action."""

    @abc.abstractmethod
    def reset(self) -> None:
        """Reset the agent to its initial state."""


@@ -0,0 +1,18 @@
from typing_extensions import override

from openpi_client import base_policy as _base_policy
from openpi_client.runtime import agent as _agent


class PolicyAgent(_agent.Agent):
    """An agent that uses a policy to determine actions."""

    def __init__(self, policy: _base_policy.BasePolicy) -> None:
        self._policy = policy

    @override
    def get_action(self, observation: dict) -> dict:
        return self._policy.infer(observation)

    def reset(self) -> None:
        self._policy.reset()


@@ -0,0 +1,32 @@
import abc


class Environment(abc.ABC):
    """An Environment represents the robot and the environment it inhabits.

    The primary contract of environments is that they can be queried for observations
    about their state, and have actions applied to them to change that state.
    """

    @abc.abstractmethod
    def reset(self) -> None:
        """Reset the environment to its initial state.

        This will be called once before starting each episode.
        """

    @abc.abstractmethod
    def is_episode_complete(self) -> bool:
        """Allow the environment to signal that the episode is complete.

        This will be called after each step. It should return `True` if the episode is
        complete (either successfully or unsuccessfully), and `False` otherwise.
        """

    @abc.abstractmethod
    def get_observation(self) -> dict:
        """Query the environment for the current state."""

    @abc.abstractmethod
    def apply_action(self, action: dict) -> None:
        """Take an action in the environment."""


@@ -0,0 +1,92 @@
import logging
import threading
import time

from openpi_client.runtime import agent as _agent
from openpi_client.runtime import environment as _environment
from openpi_client.runtime import subscriber as _subscriber


class Runtime:
    """The core module orchestrating interactions between key components of the system."""

    def __init__(
        self,
        environment: _environment.Environment,
        agent: _agent.Agent,
        subscribers: list[_subscriber.Subscriber],
        max_hz: float = 0,
        num_episodes: int = 1,
        max_episode_steps: int = 0,
    ) -> None:
        self._environment = environment
        self._agent = agent
        self._subscribers = subscribers
        self._max_hz = max_hz
        self._num_episodes = num_episodes
        self._max_episode_steps = max_episode_steps

        self._in_episode = False
        self._episode_steps = 0

    def run(self) -> None:
        """Runs the configured number of episodes, then resets the environment."""
        for _ in range(self._num_episodes):
            self._run_episode()

        # Final reset; this is important for real environments to move the robot to its home position.
        self._environment.reset()

    def run_in_new_thread(self) -> threading.Thread:
        """Runs the runtime loop in a new thread."""
        thread = threading.Thread(target=self.run)
        thread.start()
        return thread

    def mark_episode_complete(self) -> None:
        """Marks the end of an episode."""
        self._in_episode = False

    def _run_episode(self) -> None:
        """Runs a single episode."""
        logging.info("Starting episode...")
        self._environment.reset()
        self._agent.reset()
        for subscriber in self._subscribers:
            subscriber.on_episode_start()

        self._in_episode = True
        self._episode_steps = 0
        step_time = 1 / self._max_hz if self._max_hz > 0 else 0
        last_step_time = time.time()

        while self._in_episode:
            self._step()
            self._episode_steps += 1

            # Sleep to maintain the desired frame rate.
            now = time.time()
            dt = now - last_step_time
            if dt < step_time:
                time.sleep(step_time - dt)
                last_step_time = time.time()
            else:
                last_step_time = now

        logging.info("Episode completed.")
        for subscriber in self._subscribers:
            subscriber.on_episode_end()

    def _step(self) -> None:
        """A single step of the runtime loop."""
        observation = self._environment.get_observation()
        action = self._agent.get_action(observation)
        self._environment.apply_action(action)

        for subscriber in self._subscribers:
            subscriber.on_step(observation, action)

        if self._environment.is_episode_complete() or (
            self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
        ):
            self.mark_episode_complete()
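
A sketch of how the loop might be wired up end-to-end. The module path openpi_client.runtime.runtime for the Runtime class is an assumption (the filename is not shown in this diff); the toy agent and environment below are invented for illustration:

import numpy as np

from openpi_client.runtime import agent as _agent
from openpi_client.runtime import environment as _environment
from openpi_client.runtime import runtime as _runtime  # assumed module path for Runtime


class ZeroAgent(_agent.Agent):
    """Toy agent that always commands a zero action."""

    def get_action(self, observation: dict) -> dict:
        return {"actions": np.zeros(7, dtype=np.float32)}

    def reset(self) -> None:
        pass


class NeverDoneEnvironment(_environment.Environment):
    """Toy environment that never terminates on its own; episodes end via max_episode_steps."""

    def reset(self) -> None:
        pass

    def is_episode_complete(self) -> bool:
        return False

    def get_observation(self) -> dict:
        return {"state": np.zeros(7, dtype=np.float32)}

    def apply_action(self, action: dict) -> None:
        pass


runtime = _runtime.Runtime(
    environment=NeverDoneEnvironment(),
    agent=ZeroAgent(),
    subscribers=[],
    max_hz=20,              # cap the control loop at roughly 20 Hz
    num_episodes=2,
    max_episode_steps=100,  # terminates each episode, since the toy environment never does
)
runtime.run()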


@@ -0,0 +1,20 @@
import abc


class Subscriber(abc.ABC):
    """Subscribes to events in the runtime.

    Subscribers can be used to save data, visualize, etc.
    """

    @abc.abstractmethod
    def on_episode_start(self) -> None:
        """Called when an episode starts."""

    @abc.abstractmethod
    def on_step(self, observation: dict, action: dict) -> None:
        """Append a step to the episode."""

    @abc.abstractmethod
    def on_episode_end(self) -> None:
        """Called when an episode ends."""


@@ -0,0 +1,49 @@
import logging
import time
from typing import Dict, Tuple

import websockets.sync.client
from typing_extensions import override

from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy


class WebsocketClientPolicy(_base_policy.BasePolicy):
    """Implements the Policy interface by communicating with a server over websocket.

    See WebsocketPolicyServer for a corresponding server implementation.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8000) -> None:
        self._uri = f"ws://{host}:{port}"
        self._packer = msgpack_numpy.Packer()
        self._ws, self._server_metadata = self._wait_for_server()

    def get_server_metadata(self) -> Dict:
        return self._server_metadata

    def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
        logging.info(f"Waiting for server at {self._uri}...")
        while True:
            try:
                conn = websockets.sync.client.connect(self._uri, compression=None, max_size=None)
                metadata = msgpack_numpy.unpackb(conn.recv())
                return conn, metadata
            except ConnectionRefusedError:
                logging.info("Still waiting for server...")
                time.sleep(5)

    @override
    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
        data = self._packer.pack(obs)
        self._ws.send(data)
        response = self._ws.recv()
        if isinstance(response, str):
            # We're expecting bytes; if the server sends a string, it's an error.
            raise RuntimeError(f"Error in inference server:\n{response}")
        return msgpack_numpy.unpackb(response)

    @override
    def reset(self) -> None:
        pass
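
A usage sketch for the client. It assumes a compatible policy server (such as the WebsocketPolicyServer referenced in the docstring) is already listening on localhost:8000, that the module is importable as openpi_client.websocket_client_policy (the filename is not shown in this diff), and that the observation keys below match whatever schema that server expects:

import numpy as np

from openpi_client import websocket_client_policy  # assumed module name; not shown in this diff

policy = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
print("server metadata:", policy.get_server_metadata())

obs = {
    "image": np.zeros((224, 224, 3), dtype=np.uint8),  # example keys; the real schema is server-defined
    "state": np.zeros(7, dtype=np.float32),
}
action = policy.infer(obs)  # a dict of numpy arrays produced by the remote policy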