Initial commit
This commit is contained in:
36
examples/aloha_real/video_display.py
Normal file
36
examples/aloha_real/video_display.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoDisplay(_subscriber.Subscriber):
|
||||
"""Displays video frames."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._ax: plt.Axes | None = None
|
||||
self._plt_img: plt.Image | None = None
|
||||
|
||||
@override
|
||||
def on_episode_start(self) -> None:
|
||||
plt.ion()
|
||||
self._ax = plt.subplot()
|
||||
self._plt_img = None
|
||||
|
||||
@override
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
assert self._ax is not None
|
||||
|
||||
im = observation["image"][0] # [C, H, W]
|
||||
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
||||
|
||||
if self._plt_img is None:
|
||||
self._plt_img = self._ax.imshow(im)
|
||||
else:
|
||||
self._plt_img.set_data(im)
|
||||
plt.pause(0.001)
|
||||
|
||||
@override
|
||||
def on_episode_end(self) -> None:
|
||||
plt.ioff()
|
||||
plt.close()
|
||||
Reference in New Issue
Block a user