Initial commit
This commit is contained in:
70
examples/aloha_real/Dockerfile
Normal file
70
examples/aloha_real/Dockerfile
Normal file
@@ -0,0 +1,70 @@
|
||||
# Dockerfile for the Aloha real environment.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
|
||||
|
||||
FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cmake \
|
||||
curl \
|
||||
libffi-dev \
|
||||
python3-rosdep \
|
||||
python3-rosinstall \
|
||||
python3-rosinstall-generator \
|
||||
whiptail \
|
||||
git \
|
||||
wget \
|
||||
openssh-client \
|
||||
ros-noetic-cv-bridge \
|
||||
ros-noetic-usb-cam \
|
||||
ros-noetic-realsense2-camera \
|
||||
keyboard-configuration
|
||||
|
||||
WORKDIR /root
|
||||
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
|
||||
RUN chmod +x xsarm_amd64_install.sh
|
||||
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
|
||||
|
||||
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
|
||||
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
|
||||
|
||||
# Install python 3.10 because this ROS image comes with 3.8
|
||||
RUN mkdir /python && \
|
||||
cd /python && \
|
||||
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
|
||||
tar -zxvf Python-3.10.14.tgz && \
|
||||
cd Python-3.10.14 && \
|
||||
ls -lhR && \
|
||||
./configure --enable-optimizations && \
|
||||
make install && \
|
||||
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
||||
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
||||
cd ~ && rm -rf /python && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
|
||||
ENV UV_HTTP_TIMEOUT=120
|
||||
ENV UV_LINK_MODE=copy
|
||||
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
|
||||
WORKDIR /app
|
||||
|
||||
# Create an entrypoint script to run the setup commands, followed by the command passed in.
|
||||
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
|
||||
#!/bin/bash
|
||||
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
|
||||
EOF
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["python3", "/app/examples/aloha_real/main.py"]
|
||||
73
examples/aloha_real/README.md
Normal file
73
examples/aloha_real/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Run Aloha (Real Robot)
|
||||
|
||||
This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha).
|
||||
|
||||
## Prerequisites
|
||||
|
||||
This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
|
||||
|
||||
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
|
||||
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA --default_prompt='toast out of toaster'"
|
||||
docker compose -f examples/aloha_real/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.10 examples/aloha_real/.venv
|
||||
source examples/aloha_real/.venv/bin/activate
|
||||
uv pip sync examples/aloha_real/requirements.txt
|
||||
uv pip install -e packages/openpi-client
|
||||
|
||||
# Run the robot
|
||||
python examples/aloha_real/main.py
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
roslaunch --wait aloha ros_nodes.launch
|
||||
```
|
||||
|
||||
Terminal window 3:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env ALOHA --default_prompt='toast out of toaster'
|
||||
```
|
||||
|
||||
## Model Guide
|
||||
The Pi0 Base Model is an out-of-the-box model for general tasks. You can find more details in the [technical report](https://www.physicalintelligence.company/download/pi0.pdf).
|
||||
|
||||
While we strongly recommend fine-tuning the model to your own data to adapt it to particular tasks, it may be possible to prompt the model to attempt some tasks that were in the pre-training data. For example, below is a video of the model attempting the "toast out of toaster" task.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/Physical-Intelligence/openpi/blob/main/examples/aloha_real/toast.gif" alt="toast out of toaster"/>
|
||||
</p>
|
||||
|
||||
## Training on your own Aloha dataset
|
||||
|
||||
OpenPI suppports training on data collected in the default aloha hdf5 format. To do so you must first convert the data to the huggingface format. We include `scripts/aloha_hd5.py` to help you do this. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/configs.py` and replace repo id with the id assigned to your dataset during conversion.
|
||||
|
||||
```python
|
||||
TrainConfig(
|
||||
name=<your-config-name>,
|
||||
data=LeRobotAlohaDataConfig(
|
||||
repo_id=<your-repo-id>,
|
||||
delta_action_mask=[True] * 6 + [False] + [True] * 6 + [False],
|
||||
),
|
||||
),
|
||||
```
|
||||
|
||||
Run the training script:
|
||||
|
||||
```bash
|
||||
uv run scripts/train.py <your-config-name>
|
||||
```
|
||||
63
examples/aloha_real/compose.yml
Normal file
63
examples/aloha_real/compose.yml
Normal file
@@ -0,0 +1,63 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/aloha_real/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: aloha_real
|
||||
depends_on:
|
||||
- aloha_ros_nodes
|
||||
- ros_master
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/aloha_real/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
|
||||
aloha_ros_nodes:
|
||||
image: aloha_real
|
||||
depends_on:
|
||||
- ros_master
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/aloha_real/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- /dev:/dev
|
||||
command: roslaunch --wait aloha ros_nodes.launch
|
||||
|
||||
ros_master:
|
||||
image: ros:noetic-robot
|
||||
network_mode: host
|
||||
privileged: true
|
||||
command:
|
||||
- roscore
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
71
examples/aloha_real/constants.py
Normal file
71
examples/aloha_real/constants.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
|
||||
### Task parameters
|
||||
|
||||
### ALOHA fixed constants
|
||||
DT = 0.001
|
||||
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
|
||||
|
||||
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||
|
||||
# Gripper joint limits (qpos[6])
|
||||
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||
|
||||
############################ Helper functions ############################
|
||||
|
||||
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
||||
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
||||
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
||||
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
||||
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
|
||||
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
|
||||
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
||||
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
||||
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
MASTER_POS2JOINT = (
|
||||
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
+ MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
|
||||
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
PUPPET_POS2JOINT = (
|
||||
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
+ PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
|
||||
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||
52
examples/aloha_real/env.py
Normal file
52
examples/aloha_real/env.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import einops
|
||||
import numpy as np
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
|
||||
from examples.aloha_real import real_env as _real_env
|
||||
|
||||
|
||||
class AlohaRealEnvironment(_environment.Environment):
|
||||
"""An environment for an Aloha robot on real hardware."""
|
||||
|
||||
def __init__(self, render_height: int = 480, render_width: int = 640) -> None:
|
||||
self._env = _real_env.make_real_env(init_node=True)
|
||||
self._render_height = render_height
|
||||
self._render_width = render_width
|
||||
|
||||
self._ts = None
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self._ts = self._env.reset()
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
return False
|
||||
|
||||
@override
|
||||
def get_observation(self) -> dict:
|
||||
if self._ts is None:
|
||||
raise RuntimeError("Timestep is not set. Call reset() first.")
|
||||
|
||||
obs = self._ts.observation
|
||||
for k in list(obs["images"].keys()):
|
||||
if "_depth" in k:
|
||||
del obs["images"][k]
|
||||
|
||||
images = []
|
||||
for cam_name in obs["images"]:
|
||||
curr_image = obs["images"][cam_name]
|
||||
curr_image = einops.rearrange(curr_image, "h w c -> c h w")
|
||||
images.append(curr_image)
|
||||
stacked_images = np.stack(images, axis=0).astype(np.uint8)
|
||||
|
||||
# TODO: Consider removing these transformations.
|
||||
return {
|
||||
"qpos": obs["qpos"],
|
||||
"image": stacked_images,
|
||||
}
|
||||
|
||||
@override
|
||||
def apply_action(self, action: dict) -> None:
|
||||
self._ts = self._env.step(action["qpos"])
|
||||
42
examples/aloha_real/main.py
Normal file
42
examples/aloha_real/main.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import tyro
|
||||
|
||||
from examples.aloha_real import env as _env
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
action_horizon: int = 25
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
runtime = _runtime.Runtime(
|
||||
environment=_env.AlohaRealEnvironment(),
|
||||
agent=_policy_agent.PolicyAgent(
|
||||
policy=action_chunk_broker.ActionChunkBroker(
|
||||
policy=_websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
),
|
||||
action_horizon=args.action_horizon,
|
||||
)
|
||||
),
|
||||
subscribers=[],
|
||||
max_hz=50,
|
||||
)
|
||||
|
||||
runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
tyro.cli(main)
|
||||
167
examples/aloha_real/real_env.py
Normal file
167
examples/aloha_real/real_env.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
import collections
|
||||
import time
|
||||
|
||||
import dm_env
|
||||
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
|
||||
from examples.aloha_real import constants
|
||||
from examples.aloha_real import robot_utils
|
||||
|
||||
|
||||
class RealEnv:
|
||||
"""
|
||||
Environment for real robot bi-manual manipulation
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
def __init__(self, init_node, *, setup_robots: bool = True):
|
||||
self.puppet_bot_left = InterbotixManipulatorXS(
|
||||
robot_model="vx300s",
|
||||
group_name="arm",
|
||||
gripper_name="gripper",
|
||||
robot_name="puppet_left",
|
||||
init_node=init_node,
|
||||
)
|
||||
self.puppet_bot_right = InterbotixManipulatorXS(
|
||||
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
|
||||
)
|
||||
if setup_robots:
|
||||
self.setup_robots()
|
||||
|
||||
self.recorder_left = robot_utils.Recorder("left", init_node=False)
|
||||
self.recorder_right = robot_utils.Recorder("right", init_node=False)
|
||||
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
|
||||
self.gripper_command = JointSingleCommand(name="gripper")
|
||||
|
||||
def setup_robots(self):
|
||||
robot_utils.setup_puppet_bot(self.puppet_bot_left)
|
||||
robot_utils.setup_puppet_bot(self.puppet_bot_right)
|
||||
|
||||
def get_qpos(self):
|
||||
left_qpos_raw = self.recorder_left.qpos
|
||||
right_qpos_raw = self.recorder_right.qpos
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [
|
||||
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
|
||||
] # this is position not joint
|
||||
right_gripper_qpos = [
|
||||
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
|
||||
] # this is position not joint
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
def get_qvel(self):
|
||||
left_qvel_raw = self.recorder_left.qvel
|
||||
right_qvel_raw = self.recorder_right.qvel
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
|
||||
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
def get_effort(self):
|
||||
left_effort_raw = self.recorder_left.effort
|
||||
right_effort_raw = self.recorder_right.effort
|
||||
left_robot_effort = left_effort_raw[:7]
|
||||
right_robot_effort = right_effort_raw[:7]
|
||||
return np.concatenate([left_robot_effort, right_robot_effort])
|
||||
|
||||
def get_images(self):
|
||||
return self.image_recorder.get_images()
|
||||
|
||||
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
|
||||
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
|
||||
self.gripper_command.cmd = left_gripper_desired_joint
|
||||
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
|
||||
|
||||
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
|
||||
right_gripper_desired_pos_normalized
|
||||
)
|
||||
self.gripper_command.cmd = right_gripper_desired_joint
|
||||
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
|
||||
|
||||
def _reset_joints(self):
|
||||
# reset_position = START_ARM_POSE[:6]
|
||||
reset_position = [0, -1.5, 1.5, 0, 0, 0]
|
||||
robot_utils.move_arms(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [reset_position, reset_position], move_time=1
|
||||
)
|
||||
|
||||
def _reset_gripper(self):
|
||||
"""Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
|
||||
robot_utils.move_grippers(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
|
||||
)
|
||||
robot_utils.move_grippers(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
|
||||
)
|
||||
|
||||
def get_observation(self):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos()
|
||||
obs["qvel"] = self.get_qvel()
|
||||
obs["effort"] = self.get_effort()
|
||||
obs["images"] = self.get_images()
|
||||
return obs
|
||||
|
||||
def get_reward(self):
|
||||
return 0
|
||||
|
||||
def reset(self, *, fake=False):
|
||||
if not fake:
|
||||
# Reboot puppet robot gripper motors
|
||||
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self._reset_joints()
|
||||
self._reset_gripper()
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
def step(self, action):
|
||||
state_len = int(len(action) / 2)
|
||||
left_action = action[:state_len]
|
||||
right_action = action[state_len:]
|
||||
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
|
||||
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
|
||||
self.set_gripper_pose(left_action[-1], right_action[-1])
|
||||
time.sleep(constants.DT)
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
|
||||
def get_action(master_bot_left, master_bot_right):
|
||||
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
|
||||
# Arm actions
|
||||
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
||||
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
|
||||
# Gripper actions
|
||||
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
|
||||
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def make_real_env(init_node, *, setup_robots: bool = True) -> RealEnv:
|
||||
return RealEnv(init_node, setup_robots=setup_robots)
|
||||
18
examples/aloha_real/requirements.in
Normal file
18
examples/aloha_real/requirements.in
Normal file
@@ -0,0 +1,18 @@
|
||||
Pillow
|
||||
dm_control
|
||||
einops
|
||||
h5py
|
||||
matplotlib
|
||||
modern_robotics
|
||||
msgpack
|
||||
numpy
|
||||
opencv-python
|
||||
packaging
|
||||
pexpect
|
||||
pyquaternion
|
||||
pyrealsense2
|
||||
pyyaml
|
||||
requests
|
||||
rospkg
|
||||
tyro
|
||||
websockets
|
||||
156
examples/aloha_real/requirements.txt
Normal file
156
examples/aloha_real/requirements.txt
Normal file
@@ -0,0 +1,156 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
|
||||
absl-py==2.1.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# labmaze
|
||||
# mujoco
|
||||
catkin-pkg==1.0.0
|
||||
# via rospkg
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
contourpy==1.1.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
distro==1.9.0
|
||||
# via rospkg
|
||||
dm-control==1.0.23
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.8
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
docutils==0.20.1
|
||||
# via catkin-pkg
|
||||
einops==0.8.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
etils==1.3.0
|
||||
# via mujoco
|
||||
fonttools==4.55.2
|
||||
# via matplotlib
|
||||
glfw==2.8.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
h5py==3.11.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
idna==3.10
|
||||
# via requests
|
||||
importlib-resources==6.4.5
|
||||
# via etils
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lxml==5.3.0
|
||||
# via dm-control
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.7.5
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
modern-robotics==1.1.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
msgpack==1.1.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mujoco==3.2.3
|
||||
# via dm-control
|
||||
numpy==1.24.4
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# contourpy
|
||||
# dm-control
|
||||
# dm-env
|
||||
# h5py
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# modern-robotics
|
||||
# mujoco
|
||||
# opencv-python
|
||||
# pyquaternion
|
||||
# scipy
|
||||
opencv-python==4.10.0.84
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
packaging==24.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
pexpect==4.9.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
protobuf==5.29.1
|
||||
# via dm-control
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyopengl==3.1.7
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.1.4
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pyrealsense2==2.55.1.6486
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# matplotlib
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# rospkg
|
||||
requests==2.32.3
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# dm-control
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
rospkg==1.5.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
scipy==1.10.1
|
||||
# via dm-control
|
||||
setuptools==75.3.0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# labmaze
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
tqdm==4.67.1
|
||||
# via dm-control
|
||||
typeguard==4.4.0
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# etils
|
||||
# rich
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
websockets==14.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
zipp==3.20.2
|
||||
# via etils
|
||||
275
examples/aloha_real/robot_utils.py
Normal file
275
examples/aloha_real/robot_utils.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
from collections import deque
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
|
||||
from aloha.msg import RGBGrayscaleImage
|
||||
from cv_bridge import CvBridge
|
||||
from interbotix_xs_msgs.msg import JointGroupCommand
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
import rospy
|
||||
from sensor_msgs.msg import JointState
|
||||
|
||||
from examples.aloha_real import constants
|
||||
|
||||
|
||||
class ImageRecorder:
|
||||
def __init__(self, init_node=True, is_debug=False):
|
||||
self.is_debug = is_debug
|
||||
self.bridge = CvBridge()
|
||||
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("image_recorder", anonymous=True)
|
||||
for cam_name in self.camera_names:
|
||||
setattr(self, f"{cam_name}_rgb_image", None)
|
||||
setattr(self, f"{cam_name}_depth_image", None)
|
||||
setattr(self, f"{cam_name}_timestamp", 0.0)
|
||||
if cam_name == "cam_high":
|
||||
callback_func = self.image_cb_cam_high
|
||||
elif cam_name == "cam_low":
|
||||
callback_func = self.image_cb_cam_low
|
||||
elif cam_name == "cam_left_wrist":
|
||||
callback_func = self.image_cb_cam_left_wrist
|
||||
elif cam_name == "cam_right_wrist":
|
||||
callback_func = self.image_cb_cam_right_wrist
|
||||
else:
|
||||
raise NotImplementedError
|
||||
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
|
||||
if self.is_debug:
|
||||
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
|
||||
|
||||
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
|
||||
time.sleep(0.5)
|
||||
|
||||
def image_cb(self, cam_name, data):
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_rgb_image",
|
||||
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
|
||||
)
|
||||
# setattr(
|
||||
# self,
|
||||
# f"{cam_name}_depth_image",
|
||||
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
|
||||
# )
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_timestamp",
|
||||
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
|
||||
)
|
||||
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
|
||||
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
|
||||
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
|
||||
if self.is_debug:
|
||||
getattr(self, f"{cam_name}_timestamps").append(
|
||||
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
|
||||
)
|
||||
|
||||
def image_cb_cam_high(self, data):
|
||||
cam_name = "cam_high"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_low(self, data):
|
||||
cam_name = "cam_low"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_left_wrist(self, data):
|
||||
cam_name = "cam_left_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_right_wrist(self, data):
|
||||
cam_name = "cam_right_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def get_images(self):
|
||||
image_dict = {}
|
||||
for cam_name in self.camera_names:
|
||||
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
|
||||
time.sleep(0.00001)
|
||||
rgb_image = getattr(self, f"{cam_name}_rgb_image")
|
||||
depth_image = getattr(self, f"{cam_name}_depth_image")
|
||||
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
|
||||
image_dict[cam_name] = rgb_image
|
||||
image_dict[f"{cam_name}_depth"] = depth_image
|
||||
return image_dict
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
for cam_name in self.camera_names:
|
||||
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
|
||||
print(f"{cam_name} {image_freq=:.2f}")
|
||||
print()
|
||||
|
||||
|
||||
class Recorder:
|
||||
def __init__(self, side, init_node=True, is_debug=False):
|
||||
self.secs = None
|
||||
self.nsecs = None
|
||||
self.qpos = None
|
||||
self.effort = None
|
||||
self.arm_command = None
|
||||
self.gripper_command = None
|
||||
self.is_debug = is_debug
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("recorder", anonymous=True)
|
||||
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_group",
|
||||
JointGroupCommand,
|
||||
self.puppet_arm_commands_cb,
|
||||
)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_single",
|
||||
JointSingleCommand,
|
||||
self.puppet_gripper_commands_cb,
|
||||
)
|
||||
if self.is_debug:
|
||||
self.joint_timestamps = deque(maxlen=50)
|
||||
self.arm_command_timestamps = deque(maxlen=50)
|
||||
self.gripper_command_timestamps = deque(maxlen=50)
|
||||
time.sleep(0.1)
|
||||
|
||||
def puppet_state_cb(self, data):
|
||||
self.qpos = data.position
|
||||
self.qvel = data.velocity
|
||||
self.effort = data.effort
|
||||
self.data = data
|
||||
if self.is_debug:
|
||||
self.joint_timestamps.append(time.time())
|
||||
|
||||
def puppet_arm_commands_cb(self, data):
|
||||
self.arm_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.arm_command_timestamps.append(time.time())
|
||||
|
||||
def puppet_gripper_commands_cb(self, data):
|
||||
self.gripper_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.gripper_command_timestamps.append(time.time())
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
joint_freq = 1 / dt_helper(self.joint_timestamps)
|
||||
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
|
||||
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
|
||||
|
||||
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
|
||||
|
||||
|
||||
def get_arm_joint_positions(bot):
|
||||
return bot.arm.core.joint_states.position[:6]
|
||||
|
||||
|
||||
def get_arm_gripper_positions(bot):
|
||||
return bot.gripper.core.joint_states.position[6]
|
||||
|
||||
|
||||
def move_arms(bot_list, target_pose_list, move_time=1):
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
for t in range(num_steps):
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def move_grippers(bot_list, target_pose_list, move_time):
|
||||
print(f"Moving grippers to {target_pose_list=}")
|
||||
gripper_command = JointSingleCommand(name="gripper")
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
|
||||
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
|
||||
for t in range(num_steps):
|
||||
d = {}
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
gripper_command.cmd = traj_list[bot_id][t]
|
||||
bot.gripper.core.pub_single.publish(gripper_command)
|
||||
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
|
||||
f.write(json.dumps(d) + "\n")
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def setup_puppet_bot(bot):
|
||||
bot.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "position")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_on(bot)
|
||||
|
||||
|
||||
def setup_master_bot(bot):
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_off(bot)
|
||||
|
||||
|
||||
def set_standard_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def set_low_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def torque_off(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", False)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", False)
|
||||
|
||||
|
||||
def torque_on(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", True)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", True)
|
||||
|
||||
|
||||
# for DAgger
|
||||
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
|
||||
print("\nSyncing!")
|
||||
|
||||
# activate master arms
|
||||
torque_on(master_bot_left)
|
||||
torque_on(master_bot_right)
|
||||
|
||||
# get puppet arm positions
|
||||
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
|
||||
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
|
||||
|
||||
# get puppet gripper positions
|
||||
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
|
||||
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
|
||||
|
||||
# move master arms to puppet positions
|
||||
move_arms(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_qpos, puppet_right_qpos],
|
||||
move_time=1,
|
||||
)
|
||||
|
||||
# move master grippers to puppet positions
|
||||
move_grippers(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_gripper, puppet_right_gripper],
|
||||
move_time=1,
|
||||
)
|
||||
BIN
examples/aloha_real/toast.gif
Normal file
BIN
examples/aloha_real/toast.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 23 MiB |
36
examples/aloha_real/video_display.py
Normal file
36
examples/aloha_real/video_display.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoDisplay(_subscriber.Subscriber):
|
||||
"""Displays video frames."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._ax: plt.Axes | None = None
|
||||
self._plt_img: plt.Image | None = None
|
||||
|
||||
@override
|
||||
def on_episode_start(self) -> None:
|
||||
plt.ion()
|
||||
self._ax = plt.subplot()
|
||||
self._plt_img = None
|
||||
|
||||
@override
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
assert self._ax is not None
|
||||
|
||||
im = observation["image"][0] # [C, H, W]
|
||||
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
||||
|
||||
if self._plt_img is None:
|
||||
self._plt_img = self._ax.imshow(im)
|
||||
else:
|
||||
self._plt_img.set_data(im)
|
||||
plt.pause(0.001)
|
||||
|
||||
@override
|
||||
def on_episode_end(self) -> None:
|
||||
plt.ioff()
|
||||
plt.close()
|
||||
Reference in New Issue
Block a user