"""Run an `AlohaRealEnvironment` with actions served by a websocket policy client."""

import dataclasses
|
|
import logging
|
|
|
|
from openpi_client import action_chunk_broker
|
|
from openpi_client import websocket_client_policy as _websocket_client_policy
|
|
from openpi_client.runtime import runtime as _runtime
|
|
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
|
import tyro
|
|
|
|
from examples.aloha_real import env as _env
|
|
|
|
|
|
@dataclasses.dataclass
class Args:
    """Command-line options consumed by `main`."""

    # Host and port handed to the websocket policy client.
    host: str = "0.0.0.0"
    port: int = 8000

    # Forwarded as `action_horizon` to the action chunk broker.
    action_horizon: int = 25

    # Forwarded to the runtime as its episode count and per-episode step limit.
    num_episodes: int = 1
    max_episode_steps: int = 1000
|
|
|
|
|
|
def main(args: Args) -> None:
    """Build the policy client, environment and agent, then run the runtime.

    Args:
        args: Connection settings (`host`, `port`), the action horizon for the
            chunk broker, and the episode limits passed to the runtime.
    """
    ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )

    # Fetch the server metadata once and reuse it for both the log line and
    # the environment's reset position.
    metadata = ws_client_policy.get_server_metadata()
    logging.info("Server metadata: %s", metadata)

    runtime = _runtime.Runtime(
        # `reset_pose` may be absent from the metadata; `.get` then passes None.
        environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=ws_client_policy,
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[],
        max_hz=50,
        num_episodes=args.num_episodes,
        max_episode_steps=args.max_episode_steps,
    )

    runtime.run()
|
|
|
|
|
|
if __name__ == "__main__":
    # force=True removes any handlers already attached to the root logger so
    # this INFO-level configuration takes effect.
    logging.basicConfig(level=logging.INFO, force=True)
    # Let tyro build the command-line interface from `main`'s signature.
    tyro.cli(main)
|