37 lines
958 B
Python
Executable File
37 lines
958 B
Python
Executable File
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from openpi_client.runtime import subscriber as _subscriber
|
|
from typing_extensions import override
|
|
|
|
|
|
class VideoDisplay(_subscriber.Subscriber):
    """Subscriber that draws each step's image observation with matplotlib.

    An axes is obtained when an episode starts, the image artist is created on
    the first step and updated in place on later steps, and the figure is
    closed when the episode ends.
    """

    def __init__(self) -> None:
        # Both handles stay unset until `on_episode_start` runs.
        # NOTE(review): `plt.Image` may not be a real pyplot attribute; the
        # intended type is presumably matplotlib.image.AxesImage -- confirm.
        self._ax: plt.Axes | None = None
        self._plt_img: plt.Image | None = None

    @override
    def on_episode_start(self) -> None:
        """Switch pyplot to interactive mode and set up an axes for drawing."""
        plt.ion()
        self._ax = plt.subplot()
        # Forget any previous image artist so the next step calls `imshow` again.
        self._plt_img = None

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        """Show the first entry of ``observation["image"]``.

        Args:
            observation: Mapping with an ``"image"`` entry; its first element
                is displayed after moving the leading axis to the end.
            action: Not used by this subscriber.
        """
        assert self._ax is not None

        # [C, H, W] -> [H, W, C], the layout passed on to imshow/set_data.
        frame = np.transpose(observation["image"][0], (1, 2, 0))

        artist = self._plt_img
        if artist is not None:
            # Reuse the existing artist instead of creating a new one per step.
            artist.set_data(frame)
        else:
            self._plt_img = self._ax.imshow(frame)
        # Short pause so the GUI event loop can process the redraw.
        plt.pause(0.001)

    @override
    def on_episode_end(self) -> None:
        """Leave interactive mode and close the current figure."""
        plt.ioff()
        plt.close()
|