Initial commit

This commit is contained in:
Ury Zhilinsky
2025-02-03 21:43:26 -08:00
commit 231a1cf7ca
121 changed files with 16349 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
import abc
class Agent(abc.ABC):
    """Decision-making entity: the component of the system that has agency.

    An agent is handed observations describing the state of the world and
    answers with the action that should be taken in response.
    """

    @abc.abstractmethod
    def get_action(self, observation: dict) -> dict:
        """Return the next action for the given observation.

        Args:
            observation: Observation of the current state of the world.

        Returns:
            The action to take in response.
        """

    @abc.abstractmethod
    def reset(self) -> None:
        """Put the agent back into its initial state."""

View File

@@ -0,0 +1,18 @@
from typing_extensions import override
from openpi_client import base_policy as _base_policy
from openpi_client.runtime import agent as _agent
class PolicyAgent(_agent.Agent):
    """An agent that uses a policy to determine actions.

    Every call is forwarded to the wrapped policy; the agent keeps no state
    of its own beyond the policy reference.
    """

    def __init__(self, policy: _base_policy.BasePolicy) -> None:
        """Store the policy that will be queried for actions.

        Args:
            policy: Policy whose `infer` and `reset` methods back this agent.
        """
        self._policy = policy

    @override
    def get_action(self, observation: dict) -> dict:
        """Return the result of the policy's `infer` for `observation`."""
        return self._policy.infer(observation)

    # Marked @override like get_action so a type checker flags this method if
    # the base-class `reset` is ever renamed or removed.
    @override
    def reset(self) -> None:
        """Reset the wrapped policy."""
        self._policy.reset()

View File

@@ -0,0 +1,32 @@
import abc
class Environment(abc.ABC):
    """The robot together with the world it operates in.

    The core contract of an environment is twofold: it can be asked for an
    observation of its current state, and it accepts actions that change
    that state.
    """

    @abc.abstractmethod
    def reset(self) -> None:
        """Return the environment to its initial state.

        Called once before each episode starts.
        """

    @abc.abstractmethod
    def get_observation(self) -> dict:
        """Return an observation of the environment's current state."""

    @abc.abstractmethod
    def apply_action(self, action: dict) -> None:
        """Execute `action` in the environment."""

    @abc.abstractmethod
    def is_episode_complete(self) -> bool:
        """Report whether the current episode has finished.

        Called after every step.

        Returns:
            `True` once the episode is over, whether it ended successfully
            or unsuccessfully; `False` otherwise.
        """

View File

@@ -0,0 +1,92 @@
import logging
import threading
import time
from openpi_client.runtime import agent as _agent
from openpi_client.runtime import environment as _environment
from openpi_client.runtime import subscriber as _subscriber
class Runtime:
    """The core module orchestrating interactions between key components of the system.

    For every step the runtime queries the environment for an observation,
    asks the agent for an action, applies that action to the environment and
    notifies all subscribers.

    Args:
        environment: Environment that produces observations and receives actions.
        agent: Agent queried for an action on every step.
        subscribers: Objects notified on episode start, on every step and on
            episode end.
        max_hz: Upper bound on the step rate. A non-positive value disables
            rate limiting.
        num_episodes: Number of episodes executed by `run`.
        max_episode_steps: Maximum number of steps per episode. A non-positive
            value means no limit.
    """

    def __init__(
        self,
        environment: _environment.Environment,
        agent: _agent.Agent,
        subscribers: list[_subscriber.Subscriber],
        max_hz: float = 0,
        num_episodes: int = 1,
        max_episode_steps: int = 0,
    ) -> None:
        self._environment = environment
        self._agent = agent
        self._subscribers = subscribers
        self._max_hz = max_hz
        self._num_episodes = num_episodes
        self._max_episode_steps = max_episode_steps

        # True while an episode is running; cleared by mark_episode_complete().
        self._in_episode = False
        # Number of steps executed so far in the current episode.
        self._episode_steps = 0

    def run(self) -> None:
        """Run `num_episodes` episodes back to back, then reset the environment."""
        for _ in range(self._num_episodes):
            self._run_episode()

        # Final reset, this is important for real environments to move the robot to its home position.
        self._environment.reset()

    def run_in_new_thread(self) -> threading.Thread:
        """Start `run` in a new thread and return the already-started thread."""
        thread = threading.Thread(target=self.run)
        thread.start()
        return thread

    def mark_episode_complete(self) -> None:
        """Mark the current episode as finished.

        The episode loop re-checks this flag after every step, so the episode
        ends once the step in progress has finished.
        """
        self._in_episode = False

    def _run_episode(self) -> None:
        """Run a single episode from reset until it is marked complete."""
        logging.info("Starting episode...")
        self._environment.reset()
        self._agent.reset()
        for subscriber in self._subscribers:
            subscriber.on_episode_start()

        self._in_episode = True
        self._episode_steps = 0
        # Minimum duration of one step; 0 disables pacing.
        step_time = 1 / self._max_hz if self._max_hz > 0 else 0
        # Intervals are measured with the monotonic clock so wall-clock
        # adjustments cannot distort the pacing.
        last_step_time = time.monotonic()

        while self._in_episode:
            self._step()

            # Sleep to maintain the desired frame rate
            now = time.monotonic()
            dt = now - last_step_time
            if dt < step_time:
                time.sleep(step_time - dt)
                last_step_time = time.monotonic()
            else:
                last_step_time = now

        logging.info("Episode completed.")
        for subscriber in self._subscribers:
            subscriber.on_episode_end()

    def _step(self) -> None:
        """Execute a single observe / act / notify step of the runtime loop."""
        observation = self._environment.get_observation()
        action = self._agent.get_action(observation)
        self._environment.apply_action(action)

        for subscriber in self._subscribers:
            subscriber.on_step(observation, action)

        # Count this step before checking the limit so that an episode runs at
        # most `max_episode_steps` steps rather than one more.
        self._episode_steps += 1
        if self._environment.is_episode_complete() or (
            self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
        ):
            self.mark_episode_complete()

View File

@@ -0,0 +1,20 @@
import abc
class Subscriber(abc.ABC):
    """Listener for events emitted by the runtime.

    Typical uses include saving data and visualization.
    """

    @abc.abstractmethod
    def on_episode_start(self) -> None:
        """Handle the start of an episode."""

    @abc.abstractmethod
    def on_step(self, observation: dict, action: dict) -> None:
        """Handle one step of the episode.

        Args:
            observation: Observation passed for this step.
            action: Action passed for this step.
        """

    @abc.abstractmethod
    def on_episode_end(self) -> None:
        """Handle the end of an episode."""