Initial commit
This commit is contained in:
1
packages/openpi-client/src/openpi_client/__init__.py
Normal file
1
packages/openpi-client/src/openpi_client/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import tree
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
|
||||
|
||||
class ActionChunkBroker(_base_policy.BasePolicy):
|
||||
"""Wraps a policy to return action chunks one-at-a-time.
|
||||
|
||||
Assumes that the first dimension of all action fields is the chunk size.
|
||||
|
||||
A new inference call to the inner policy is only made when the current
|
||||
list of chunks is exhausted.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
|
||||
self._policy = policy
|
||||
|
||||
self._action_horizon = action_horizon
|
||||
self._cur_step: int = 0
|
||||
|
||||
self._last_results: Dict[str, np.ndarray] | None = None
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
if self._last_results is None:
|
||||
self._last_results = self._policy.infer(obs)
|
||||
self._cur_step = 0
|
||||
|
||||
results = tree.map_structure(lambda x: x[self._cur_step, ...], self._last_results)
|
||||
self._cur_step += 1
|
||||
|
||||
if self._cur_step >= self._action_horizon:
|
||||
self._last_results = None
|
||||
|
||||
return results
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self._policy.reset()
|
||||
self._last_results = None
|
||||
self._cur_step = 0
|
||||
12
packages/openpi-client/src/openpi_client/base_policy.py
Normal file
12
packages/openpi-client/src/openpi_client/base_policy.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import abc
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class BasePolicy(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def infer(self, obs: Dict) -> Dict:
|
||||
"""Infer actions from observations."""
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the policy to its initial state."""
|
||||
pass
|
||||
58
packages/openpi-client/src/openpi_client/image_tools.py
Normal file
58
packages/openpi-client/src/openpi_client/image_tools.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
|
||||
"""Converts an image to uint8 if it is a float image.
|
||||
|
||||
This is important for reducing the size of the image when sending it over the network.
|
||||
"""
|
||||
if np.issubdtype(img.dtype, np.floating):
|
||||
img = (255 * img).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
|
||||
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
|
||||
|
||||
Args:
|
||||
images: A batch of images in [..., height, width, channel] format.
|
||||
height: The target height of the image.
|
||||
width: The target width of the image.
|
||||
method: The interpolation method to use. Default is bilinear.
|
||||
|
||||
Returns:
|
||||
The resized images in [..., height, width, channel].
|
||||
"""
|
||||
# If the images are already the correct size, return them as is.
|
||||
if images.shape[-3:-1] == (height, width):
|
||||
return images
|
||||
|
||||
original_shape = images.shape
|
||||
|
||||
images = images.reshape(-1, *original_shape[-3:])
|
||||
resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
|
||||
return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
|
||||
|
||||
|
||||
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
|
||||
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
|
||||
width without distortion by padding with zeros.
|
||||
|
||||
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
|
||||
"""
|
||||
cur_width, cur_height = image.size
|
||||
if cur_width == width and cur_height == height:
|
||||
return image # No need to resize if the image is already the correct size.
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_image = image.resize((resized_width, resized_height), resample=method)
|
||||
|
||||
zero_image = Image.new(resized_image.mode, (width, height), 0)
|
||||
pad_height = max(0, int((height - resized_height) / 2))
|
||||
pad_width = max(0, int((width - resized_width) / 2))
|
||||
zero_image.paste(resized_image, (pad_width, pad_height))
|
||||
assert zero_image.size == (width, height)
|
||||
return zero_image
|
||||
37
packages/openpi-client/src/openpi_client/image_tools_test.py
Normal file
37
packages/openpi-client/src/openpi_client/image_tools_test.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
|
||||
import openpi_client.image_tools as image_tools
|
||||
|
||||
|
||||
def test_resize_with_pad_shapes():
|
||||
# Test case 1: Resize image with larger dimensions
|
||||
images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels)
|
||||
height = 20
|
||||
width = 20
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (2, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 2: Resize image with smaller dimensions
|
||||
images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
|
||||
height = 15
|
||||
width = 15
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (3, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 3: Resize image with the same dimensions
|
||||
images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
|
||||
height = 50
|
||||
width = 50
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 3: Resize image with odd-numbered padding
|
||||
images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
|
||||
height = 60
|
||||
width = 80
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
57
packages/openpi-client/src/openpi_client/msgpack_numpy.py
Normal file
57
packages/openpi-client/src/openpi_client/msgpack_numpy.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Adds NumPy array support to msgpack.
|
||||
|
||||
msgpack is good for (de)serializing data over a network for multiple reasons:
|
||||
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
|
||||
- msgpack is widely used and has good cross-language support
|
||||
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
|
||||
languages like Python and JavaScript
|
||||
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
|
||||
than pickle for serializing large arrays using the below strategy
|
||||
|
||||
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
|
||||
that it falls back to pickle for object arrays.
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pack_array(obj):
|
||||
if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
|
||||
raise ValueError(f"Unsupported dtype: {obj.dtype}")
|
||||
|
||||
if isinstance(obj, np.ndarray):
|
||||
return {
|
||||
b"__ndarray__": True,
|
||||
b"data": obj.tobytes(),
|
||||
b"dtype": obj.dtype.str,
|
||||
b"shape": obj.shape,
|
||||
}
|
||||
|
||||
if isinstance(obj, np.generic):
|
||||
return {
|
||||
b"__npgeneric__": True,
|
||||
b"data": obj.item(),
|
||||
b"dtype": obj.dtype.str,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def unpack_array(obj):
|
||||
if b"__ndarray__" in obj:
|
||||
return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
|
||||
|
||||
if b"__npgeneric__" in obj:
|
||||
return np.dtype(obj[b"dtype"]).type(obj[b"data"])
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
Packer = functools.partial(msgpack.Packer, default=pack_array)
|
||||
packb = functools.partial(msgpack.packb, default=pack_array)
|
||||
|
||||
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
|
||||
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
|
||||
@@ -0,0 +1,45 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tree
|
||||
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
def _check(expected, actual):
|
||||
if isinstance(expected, np.ndarray):
|
||||
assert expected.shape == actual.shape
|
||||
assert expected.dtype == actual.dtype
|
||||
assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
|
||||
else:
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data",
|
||||
[
|
||||
1, # int
|
||||
1.0, # float
|
||||
"hello", # string
|
||||
np.bool_(True), # boolean scalar
|
||||
np.array([1, 2, 3])[0], # int scalar
|
||||
np.str_("asdf"), # string scalar
|
||||
[1, 2, 3], # list
|
||||
{"key": "value"}, # dict
|
||||
{"key": [1, 2, 3]}, # nested dict
|
||||
np.array(1.0), # 0D array
|
||||
np.array([1, 2, 3], dtype=np.int32), # 1D integer array
|
||||
np.array(["asdf", "qwer"]), # string array
|
||||
np.array([True, False]), # boolean array
|
||||
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
|
||||
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array
|
||||
np.array([np.nan, np.inf, -np.inf]), # special float values
|
||||
{"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays
|
||||
[np.array([1, 2]), np.array([3, 4])], # list of arrays
|
||||
np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
|
||||
np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
|
||||
],
|
||||
)
|
||||
def test_pack_unpack(data):
|
||||
packed = msgpack_numpy.packb(data)
|
||||
unpacked = msgpack_numpy.unpackb(packed)
|
||||
tree.map_structure(_check, data, unpacked)
|
||||
17
packages/openpi-client/src/openpi_client/runtime/agent.py
Normal file
17
packages/openpi-client/src/openpi_client/runtime/agent.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Agent(abc.ABC):
|
||||
"""An Agent is the thing with agency, i.e. the entity that makes decisions.
|
||||
|
||||
Agents receive observations about the state of the world, and return actions
|
||||
to take in response.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
"""Query the agent for the next action."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the agent to its initial state."""
|
||||
@@ -0,0 +1,18 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client.runtime import agent as _agent
|
||||
|
||||
|
||||
class PolicyAgent(_agent.Agent):
|
||||
"""An agent that uses a policy to determine actions."""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy) -> None:
|
||||
self._policy = policy
|
||||
|
||||
@override
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
return self._policy.infer(observation)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._policy.reset()
|
||||
@@ -0,0 +1,32 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Environment(abc.ABC):
|
||||
"""An Environment represents the robot and the environment it inhabits.
|
||||
|
||||
The primary contract of environments is that they can be queried for observations
|
||||
about their state, and have actions applied to them to change that state.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the environment to its initial state.
|
||||
|
||||
This will be called once before starting each episode.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_episode_complete(self) -> bool:
|
||||
"""Allow the environment to signal that the episode is complete.
|
||||
|
||||
This will be called after each step. It should return `True` if the episode is
|
||||
complete (either successfully or unsuccessfully), and `False` otherwise.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_observation(self) -> dict:
|
||||
"""Query the environment for the current state."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def apply_action(self, action: dict) -> None:
|
||||
"""Take an action in the environment."""
|
||||
92
packages/openpi-client/src/openpi_client/runtime/runtime.py
Normal file
92
packages/openpi-client/src/openpi_client/runtime/runtime.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from openpi_client.runtime import agent as _agent
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
|
||||
|
||||
class Runtime:
|
||||
"""The core module orchestrating interactions between key components of the system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
environment: _environment.Environment,
|
||||
agent: _agent.Agent,
|
||||
subscribers: list[_subscriber.Subscriber],
|
||||
max_hz: float = 0,
|
||||
num_episodes: int = 1,
|
||||
max_episode_steps: int = 0,
|
||||
) -> None:
|
||||
self._environment = environment
|
||||
self._agent = agent
|
||||
self._subscribers = subscribers
|
||||
self._max_hz = max_hz
|
||||
self._num_episodes = num_episodes
|
||||
self._max_episode_steps = max_episode_steps
|
||||
|
||||
self._in_episode = False
|
||||
self._episode_steps = 0
|
||||
|
||||
def run(self) -> None:
|
||||
"""Runs the runtime loop continuously until stop() is called or the environment is done."""
|
||||
for _ in range(self._num_episodes):
|
||||
self._run_episode()
|
||||
|
||||
# Final reset, this is important for real environments to move the robot to its home position.
|
||||
self._environment.reset()
|
||||
|
||||
def run_in_new_thread(self) -> threading.Thread:
|
||||
"""Runs the runtime loop in a new thread."""
|
||||
thread = threading.Thread(target=self.run)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
def mark_episode_complete(self) -> None:
|
||||
"""Marks the end of an episode."""
|
||||
self._in_episode = False
|
||||
|
||||
def _run_episode(self) -> None:
|
||||
"""Runs a single episode."""
|
||||
logging.info("Starting episode...")
|
||||
self._environment.reset()
|
||||
self._agent.reset()
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_start()
|
||||
|
||||
self._in_episode = True
|
||||
self._episode_steps = 0
|
||||
step_time = 1 / self._max_hz if self._max_hz > 0 else 0
|
||||
last_step_time = time.time()
|
||||
|
||||
while self._in_episode:
|
||||
self._step()
|
||||
self._episode_steps += 1
|
||||
|
||||
# Sleep to maintain the desired frame rate
|
||||
now = time.time()
|
||||
dt = now - last_step_time
|
||||
if dt < step_time:
|
||||
time.sleep(step_time - dt)
|
||||
last_step_time = time.time()
|
||||
else:
|
||||
last_step_time = now
|
||||
|
||||
logging.info("Episode completed.")
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_end()
|
||||
|
||||
def _step(self) -> None:
|
||||
"""A single step of the runtime loop."""
|
||||
observation = self._environment.get_observation()
|
||||
action = self._agent.get_action(observation)
|
||||
self._environment.apply_action(action)
|
||||
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_step(observation, action)
|
||||
|
||||
if self._environment.is_episode_complete() or (
|
||||
self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
|
||||
):
|
||||
self.mark_episode_complete()
|
||||
@@ -0,0 +1,20 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Subscriber(abc.ABC):
|
||||
"""Subscribes to events in the runtime.
|
||||
|
||||
Subscribers can be used to save data, visualize, etc.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_start(self) -> None:
|
||||
"""Called when an episode starts."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
"""Append a step to the episode."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_end(self) -> None:
|
||||
"""Called when an episode ends."""
|
||||
@@ -0,0 +1,49 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import websockets.sync.client
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
class WebsocketClientPolicy(_base_policy.BasePolicy):
|
||||
"""Implements the Policy interface by communicating with a server over websocket.
|
||||
|
||||
See WebsocketPolicyServer for a corresponding server implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000) -> None:
|
||||
self._uri = f"ws://{host}:{port}"
|
||||
self._packer = msgpack_numpy.Packer()
|
||||
self._ws, self._server_metadata = self._wait_for_server()
|
||||
|
||||
def get_server_metadata(self) -> Dict:
|
||||
return self._server_metadata
|
||||
|
||||
def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
|
||||
logging.info(f"Waiting for server at {self._uri}...")
|
||||
while True:
|
||||
try:
|
||||
conn = websockets.sync.client.connect(self._uri, compression=None, max_size=None)
|
||||
metadata = msgpack_numpy.unpackb(conn.recv())
|
||||
return conn, metadata
|
||||
except ConnectionRefusedError:
|
||||
logging.info("Still waiting for server...")
|
||||
time.sleep(5)
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
data = self._packer.pack(obs)
|
||||
self._ws.send(data)
|
||||
response = self._ws.recv()
|
||||
if isinstance(response, str):
|
||||
# we're expecting bytes; if the server sends a string, it's an error.
|
||||
raise RuntimeError(f"Error in inference server:\n{response}")
|
||||
return msgpack_numpy.unpackb(response)
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
Reference in New Issue
Block a user