Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 943c0ebd64 | |||
| 589bbf9479 | |||
| 2779a2856b | |||
| a87dce9e3f | |||
| 3685542bf1 | |||
| 7c1699898b | |||
| b3e9e11e11 | |||
| b04e6e0c7b | |||
| 96804bc86c | |||
| ef45ea9649 | |||
| bc351a0134 | |||
| 68986f6fc0 | |||
| 2f124e34de | |||
| c28e774234 | |||
| 80b1a97e4c | |||
| f4fec8f51c | |||
| f4f82c916f | |||
| ecbe154709 | |||
| d00c154db9 | |||
| 55f284b306 | |||
| cf8df17d3a | |||
| e079566597 | |||
| 83d6419d70 | |||
| a0ec9e1cb1 | |||
| 3eede4447d | |||
| 9c6a7d9701 | |||
| 7b201773f3 |
154
realman.md
Normal file
154
realman.md
Normal file
@@ -0,0 +1,154 @@
|
||||
# Install
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
pip install edge-tts
|
||||
sudo apt install mpv -y
|
||||
|
||||
# pip uninstall numpy
|
||||
# pip install numpy==1.26.0
|
||||
# pip install pynput
|
||||
```
|
||||
|
||||
/!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:
|
||||
```bash
|
||||
conda install ffmpeg=7.1.1 -c conda-forge
|
||||
# pip uninstall opencv-python
|
||||
# conda install "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
Install Realman SDK:
|
||||
```bash
|
||||
pip install Robotic_Arm==1.0.4.1
|
||||
pip install pygame
|
||||
```
|
||||
|
||||
# piper集成lerobot
|
||||
见lerobot_piper_tutorial/1. 🤗 LeRobot:新增机械臂的一般流程.pdf
|
||||
|
||||
# Teleoperate
|
||||
```python
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=realman_dual \
|
||||
--robot.inference_time=false \
|
||||
--control.type=teleoperate \
|
||||
--control.display_data=true
|
||||
```
|
||||
display_data=true turn on run.io else turn off.
|
||||
|
||||
# Record
|
||||
Set dataset root path
|
||||
```bash
|
||||
HF_USER=$PWD/data
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=realman \
|
||||
--robot.inference_time=false \
|
||||
--control.type=record \
|
||||
--control.fps=15 \
|
||||
--control.single_task="move" \
|
||||
--control.repo_id=maic/test \
|
||||
--control.num_episodes=2 \
|
||||
--control.warmup_time_s=2 \
|
||||
--control.episode_time_s=10 \
|
||||
--control.reset_time_s=10 \
|
||||
--control.play_sounds=true \
|
||||
--control.push_to_hub=false \
|
||||
--control.display_data=true
|
||||
```
|
||||
|
||||
Press right arrow -> at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording.
|
||||
Press left arrow <- at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it.
|
||||
Press escape ESC at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
|
||||
|
||||
# visualize
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id ${HF_USER}/test \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
# Replay
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=piper \
|
||||
--robot.inference_time=false \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
# Caution
|
||||
|
||||
1. In lerobots/common/datasets/video_utils, the vcodec is set to **libopenh264**, please find your vcodec by **ffmpeg -codecs**
|
||||
|
||||
|
||||
# Train
|
||||
具体的训练流程见lerobot_piper_tutorial/2. 🤗 AutoDL训练.pdf
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/jack \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_jack \
|
||||
--job_name=act_jack \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
# FT smolvla
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--output_dir=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||
--job_name=smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=false \
|
||||
--steps=200000 \
|
||||
--batch_size=16
|
||||
|
||||
|
||||
# Inference
|
||||
还是使用control_robot.py中的record loop,配置 **--robot.inference_time=true** 可以将手柄移出。
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=realman \
|
||||
--robot.inference_time=true \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="move the bottle into ultrasonic device with realman single" \
|
||||
--control.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||
--control.num_episodes=1 \
|
||||
--control.warmup_time_s=2 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=10 \
|
||||
--control.push_to_hub=false \
|
||||
--control.policy.path=outputs/train/act_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/100000/pretrained_model
|
||||
```
|
||||
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=realman \
|
||||
--robot.inference_time=true \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="move the bottle into ultrasonic device with realman single" \
|
||||
--control.repo_id=maic/eval_smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||
--control.num_episodes=1 \
|
||||
--control.warmup_time_s=2 \
|
||||
--control.episode_time_s=60 \
|
||||
--control.reset_time_s=10 \
|
||||
--control.push_to_hub=false \
|
||||
--control.policy.path=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/160000/pretrained_model \
|
||||
--control.display_data=true
|
||||
```
|
||||
31
realman_src/dual_arm_connect_test.py
Normal file
31
realman_src/dual_arm_connect_test.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from Robotic_Arm.rm_robot_interface import *
|
||||
|
||||
armleft = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
|
||||
armright = RoboticArm()
|
||||
|
||||
|
||||
lefthandle = armleft.rm_create_robot_arm("169.254.128.18", 8080)
|
||||
print("机械臂ID:", lefthandle.id)
|
||||
righthandle = armright.rm_create_robot_arm("169.254.128.19", 8080)
|
||||
print("机械臂ID:", righthandle.id)
|
||||
|
||||
# software_info = armleft.rm_get_arm_software_info()
|
||||
# if software_info[0] == 0:
|
||||
# print("\n================== Arm Software Information ==================")
|
||||
# print("Arm Model: ", software_info[1]['product_version'])
|
||||
# print("Algorithm Library Version: ", software_info[1]['algorithm_info']['version'])
|
||||
# print("Control Layer Software Version: ", software_info[1]['ctrl_info']['version'])
|
||||
# print("Dynamics Version: ", software_info[1]['dynamic_info']['model_version'])
|
||||
# print("Planning Layer Software Version: ", software_info[1]['plan_info']['version'])
|
||||
# print("==============================================================\n")
|
||||
# else:
|
||||
# print("\nFailed to get arm software information, Error code: ", software_info[0], "\n")
|
||||
|
||||
print("Left: ", armleft.rm_get_current_arm_state())
|
||||
print("Left: ", armleft.rm_get_arm_all_state())
|
||||
armleft.rm_movej_p()
|
||||
# print("Right: ", armright.rm_get_current_arm_state())
|
||||
|
||||
|
||||
# 断开所有连接,销毁线程
|
||||
RoboticArm.rm_destory()
|
||||
15
realman_src/movep_canfd.py
Normal file
15
realman_src/movep_canfd.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from Robotic_Arm.rm_robot_interface import *
|
||||
import time
|
||||
|
||||
# 实例化RoboticArm类
|
||||
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
|
||||
# 创建机械臂连接,打印连接id
|
||||
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
|
||||
print(handle.id)
|
||||
|
||||
print(arm.rm_movep_follow([-0.330512, 0.255993, -0.161205, 3.141, 0.0, -1.57]))
|
||||
time.sleep(2)
|
||||
# print(arm.rm_movep_follow([0.3, 0, 0.3, 3.14, 0, 0]))
|
||||
# time.sleep(2)
|
||||
|
||||
arm.rm_delete_robot_arm()
|
||||
0
realman_src/realman_aloha/__init__.py
Normal file
0
realman_src/realman_aloha/__init__.py
Normal file
4
realman_src/realman_aloha/shadow_camera/.gitignore
vendored
Normal file
4
realman_src/realman_aloha/shadow_camera/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pt
|
||||
0
realman_src/realman_aloha/shadow_camera/README.md
Normal file
0
realman_src/realman_aloha/shadow_camera/README.md
Normal file
0
realman_src/realman_aloha/shadow_camera/__init__.py
Normal file
0
realman_src/realman_aloha/shadow_camera/__init__.py
Normal file
33
realman_src/realman_aloha/shadow_camera/pyproject.toml
Normal file
33
realman_src/realman_aloha/shadow_camera/pyproject.toml
Normal file
@@ -0,0 +1,33 @@
|
||||
[tool.poetry]
|
||||
name = "shadow_camera"
|
||||
version = "0.1.0"
|
||||
description = "camera class, currently includes realsense"
|
||||
readme = "README.md"
|
||||
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||
license = "MIT"
|
||||
#include = ["realman_vision/pytransform/_pytransform.so",]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux amd64",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9"
|
||||
numpy = ">=2.0.1"
|
||||
opencv-python = ">=4.10.0.84"
|
||||
pyrealsense2 = ">=2.55.1.6486"
|
||||
|
||||
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
|
||||
pytest = ">=8.3"
|
||||
black = ">=24.10.0"
|
||||
|
||||
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.8.4"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,38 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class BaseCamera(metaclass=ABCMeta):
|
||||
"""摄像头基类"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_camera(self):
|
||||
"""启动相机"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop_camera(self):
|
||||
"""停止相机"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_resolution(self, resolution_width, resolution_height):
|
||||
"""设置相机彩色图像分辨率"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_frame_rate(self, fps):
|
||||
"""设置相机彩色图像帧率"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def read_frame(self):
|
||||
"""读取一帧彩色图像和深度图像"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_camera_intrinsics(self):
|
||||
"""获取彩色图像和深度图像的内参"""
|
||||
pass
|
||||
Binary file not shown.
@@ -0,0 +1,38 @@
|
||||
from shadow_camera import base_camera
|
||||
import cv2
|
||||
|
||||
class OpenCVCamera(base_camera.BaseCamera):
|
||||
"""基于OpenCV的摄像头类"""
|
||||
|
||||
def __init__(self, device_id=0):
|
||||
"""初始化视频捕获
|
||||
|
||||
参数:
|
||||
device_id: 摄像头设备ID
|
||||
"""
|
||||
self.cap = cv2.VideoCapture(device_id)
|
||||
|
||||
def get_frame(self):
|
||||
"""获取当前帧
|
||||
|
||||
返回:
|
||||
frame: 当前帧的图像数据,取不到时返回None
|
||||
"""
|
||||
ret, frame = self.cap.read()
|
||||
return frame if ret else None
|
||||
|
||||
def get_frame_info(self):
|
||||
"""获取当前帧信息
|
||||
|
||||
返回:
|
||||
dict: 帧信息字典
|
||||
"""
|
||||
width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
channels = int(self.cap.get(cv2.CAP_PROP_FRAME_CHANNELS))
|
||||
|
||||
return {
|
||||
'width': width,
|
||||
'height': height,
|
||||
'channels': channels
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,280 @@
|
||||
import time
|
||||
import logging
|
||||
import numpy as np
|
||||
import pyrealsense2 as rs
|
||||
import base_camera
|
||||
|
||||
# 设置日志配置
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
class RealSenseCamera(base_camera.BaseCamera):
|
||||
"""Intel RealSense相机类"""
|
||||
|
||||
def __init__(self, serial_num=None, is_depth_frame=False):
|
||||
"""
|
||||
初始化相机对象
|
||||
:param serial_num: 相机序列号,默认为None
|
||||
"""
|
||||
super().__init__()
|
||||
self._color_resolution = [640, 480]
|
||||
self._depth_resolution = [640, 480]
|
||||
self._color_frames_rate = 30
|
||||
self._depth_frames_rate = 15
|
||||
self.timestamp = 0
|
||||
self.color_timestamp = 0
|
||||
self.depth_timestamp = 0
|
||||
self._colorizer = rs.colorizer()
|
||||
self._config = rs.config()
|
||||
self.is_depth_frame = is_depth_frame
|
||||
self.camera_on = False
|
||||
self.serial_num = serial_num
|
||||
|
||||
def get_serial_num(self):
|
||||
serial_num = {}
|
||||
context = rs.context()
|
||||
devices = context.query_devices() # 获取所有设备
|
||||
if len(context.devices) > 0:
|
||||
for i, device in enumerate(devices):
|
||||
serial_num[i] = device.get_info(rs.camera_info.serial_number)
|
||||
|
||||
logging.info(f"Detected serial numbers: {serial_num}")
|
||||
return serial_num
|
||||
|
||||
def _set_config(self):
|
||||
if self.serial_num is not None:
|
||||
logging.info(f"Setting device with serial number: {self.serial_num}")
|
||||
self._config.enable_device(self.serial_num)
|
||||
|
||||
self._config.enable_stream(
|
||||
rs.stream.color,
|
||||
self._color_resolution[0],
|
||||
self._color_resolution[1],
|
||||
rs.format.rgb8,
|
||||
self._color_frames_rate,
|
||||
)
|
||||
if self.is_depth_frame:
|
||||
self._config.enable_stream(
|
||||
rs.stream.depth,
|
||||
self._depth_resolution[0],
|
||||
self._depth_resolution[1],
|
||||
rs.format.z16,
|
||||
self._depth_frames_rate,
|
||||
)
|
||||
|
||||
def start_camera(self):
|
||||
"""
|
||||
启动相机并获取内参信息,如果后续调用帧对齐,则内参均为彩色内参
|
||||
"""
|
||||
self._pipeline = rs.pipeline()
|
||||
if self.is_depth_frame:
|
||||
self.point_cloud = rs.pointcloud()
|
||||
self._align = rs.align(rs.stream.color)
|
||||
self._set_config()
|
||||
|
||||
self.profile = self._pipeline.start(self._config)
|
||||
|
||||
if self.is_depth_frame:
|
||||
self._depth_intrinsics = (
|
||||
self.profile.get_stream(rs.stream.depth)
|
||||
.as_video_stream_profile()
|
||||
.get_intrinsics()
|
||||
)
|
||||
|
||||
self._color_intrinsics = (
|
||||
self.profile.get_stream(rs.stream.color)
|
||||
.as_video_stream_profile()
|
||||
.get_intrinsics()
|
||||
)
|
||||
self.camera_on = True
|
||||
logging.info("Camera started successfully")
|
||||
logging.info(
|
||||
f"Camera started with color resolution: {self._color_resolution}, depth resolution: {self._depth_resolution}"
|
||||
)
|
||||
logging.info(
|
||||
f"Color FPS: {self._color_frames_rate}, Depth FPS: {self._depth_frames_rate}"
|
||||
)
|
||||
|
||||
def stop_camera(self):
|
||||
"""
|
||||
停止相机
|
||||
"""
|
||||
self._pipeline.stop()
|
||||
self.camera_on = False
|
||||
logging.info("Camera stopped")
|
||||
|
||||
def set_resolution(self, color_resolution, depth_resolution):
|
||||
self._color_resolution = color_resolution
|
||||
self._depth_resolution = depth_resolution
|
||||
logging.info(
|
||||
"Optional color resolution:"
|
||||
"[320, 180] [320, 240] [424, 240] [640, 360] [640, 480]"
|
||||
"[848, 480] [960, 540] [1280, 720] [1920, 1080]"
|
||||
)
|
||||
logging.info(
|
||||
"Optional depth resolution:"
|
||||
"[256, 144] [424, 240] [480, 270] [640, 360] [640, 400]"
|
||||
"[640, 480] [848, 100] [848, 480] [1280, 720] [1280, 800]"
|
||||
)
|
||||
logging.info(f"Set color resolution to: {color_resolution}")
|
||||
logging.info(f"Set depth resolution to: {depth_resolution}")
|
||||
|
||||
def set_frame_rate(self, color_fps, depth_fps):
|
||||
self._color_frames_rate = color_fps
|
||||
self._depth_frames_rate = depth_fps
|
||||
logging.info("Optional color fps: 6 15 30 60 ")
|
||||
logging.info("Optional depth fps: 6 15 30 60 90 100 300")
|
||||
logging.info(f"Set color FPS to: {color_fps}")
|
||||
logging.info(f"Set depth FPS to: {depth_fps}")
|
||||
|
||||
# TODO: 调节白平衡进行补偿
|
||||
# def set_exposure(self, exposure):
|
||||
|
||||
def read_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
|
||||
"""
|
||||
读取一帧彩色图像和深度图像
|
||||
:return: 彩色图像和深度图像的NumPy数组
|
||||
"""
|
||||
while not self.camera_on:
|
||||
time.sleep(0.5)
|
||||
color_image = None
|
||||
depth_image = None
|
||||
colorized_depth = None
|
||||
point_cloud = None
|
||||
try:
|
||||
frames = self._pipeline.wait_for_frames()
|
||||
if is_color:
|
||||
color_frame = frames.get_color_frame()
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
else:
|
||||
color_image = None
|
||||
|
||||
if is_depth:
|
||||
depth_frame = frames.get_depth_frame()
|
||||
depth_image = np.asanyarray(depth_frame.get_data())
|
||||
else:
|
||||
depth_image = None
|
||||
|
||||
colorized_depth = (
|
||||
np.asanyarray(self._colorizer.colorize(depth_frame).get_data())
|
||||
if is_colorized_depth
|
||||
else None
|
||||
)
|
||||
point_cloud = (
|
||||
np.asanyarray(self.point_cloud.calculate(depth_frame).get_vertices())
|
||||
if is_point_cloud
|
||||
else None
|
||||
)
|
||||
# 获取时间戳单位为ms,对齐后color时间戳 > depth = aligned,选择color
|
||||
self.color_timestamp = color_frame.get_timestamp()
|
||||
if self.is_depth_frame:
|
||||
self.depth_timestamp = depth_frame.get_timestamp()
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(e)
|
||||
if "Frame didn't arrive within 5000" in str(e):
|
||||
logging.warning("Frame didn't arrive within 5000ms, resetting device")
|
||||
self.stop_camera()
|
||||
self.start_camera()
|
||||
|
||||
return color_image, depth_image, colorized_depth, point_cloud
|
||||
|
||||
def read_align_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
|
||||
"""
|
||||
读取一帧对齐的彩色图像和深度图像
|
||||
:return: 彩色图像和深度图像的NumPy数组
|
||||
"""
|
||||
while not self.camera_on:
|
||||
time.sleep(0.5)
|
||||
try:
|
||||
frames = self._pipeline.wait_for_frames()
|
||||
aligned_frames = self._align.process(frames)
|
||||
aligned_color_frame = aligned_frames.get_color_frame()
|
||||
self._aligned_depth_frame = aligned_frames.get_depth_frame()
|
||||
|
||||
color_image = np.asanyarray(aligned_color_frame.get_data())
|
||||
depth_image = np.asanyarray(self._aligned_depth_frame.get_data())
|
||||
colorized_depth = (
|
||||
np.asanyarray(
|
||||
self._colorizer.colorize(self._aligned_depth_frame).get_data()
|
||||
)
|
||||
if is_colorized_depth
|
||||
else None
|
||||
)
|
||||
|
||||
if is_point_cloud:
|
||||
points = self.point_cloud.calculate(self._aligned_depth_frame)
|
||||
# 将元组数据转换为 NumPy 数组
|
||||
point_cloud = np.array(
|
||||
[[point[0], point[1], point[2]] for point in points.get_vertices()]
|
||||
)
|
||||
else:
|
||||
point_cloud = None
|
||||
|
||||
# 获取时间戳单位为ms,对齐后color时间戳 > depth = aligned,选择color
|
||||
self.timestamp = aligned_color_frame.get_timestamp()
|
||||
|
||||
return color_image, depth_image, colorized_depth, point_cloud
|
||||
|
||||
except Exception as e:
|
||||
if "Frame didn't arrive within 5000" in str(e):
|
||||
logging.warning("Frame didn't arrive within 5000ms, resetting device")
|
||||
self.stop_camera()
|
||||
self.start_camera()
|
||||
# device = self.profile.get_device()
|
||||
# device.hardware_reset()
|
||||
|
||||
def get_camera_intrinsics(self):
|
||||
"""
|
||||
获取彩色图像和深度图像的内参信息
|
||||
:return: 彩色图像和深度图像的内参信息
|
||||
"""
|
||||
# 宽高:.width, .height; 焦距:.fx, .fy; 像素坐标:.ppx, .ppy; 畸变系数:.coeffs
|
||||
logging.info("Getting camera intrinsics")
|
||||
logging.info(
|
||||
"Width and height: .width, .height; Focal length: .fx, .fy; Pixel coordinates: .ppx, .ppy; Distortion coefficient: .coeffs"
|
||||
)
|
||||
return self._color_intrinsics, self._depth_intrinsics
|
||||
|
||||
def get_3d_camera_coordinate(self, depth_pixel, align=False):
|
||||
"""
|
||||
获取深度相机坐标系下的三维坐标
|
||||
:param depth_pixel:深度像素坐标
|
||||
:param align: 是否对齐
|
||||
|
||||
:return: 深度值和相机坐标
|
||||
"""
|
||||
if not hasattr(self, "_aligned_depth_frame"):
|
||||
raise AttributeError(
|
||||
"Aligned depth frame not set. Call read_align_frame() first."
|
||||
)
|
||||
|
||||
distance = self._aligned_depth_frame.get_distance(
|
||||
depth_pixel[0], depth_pixel[1]
|
||||
)
|
||||
intrinsics = self._color_intrinsics if align else self._depth_intrinsics
|
||||
camera_coordinate = rs.rs2_deproject_pixel_to_point(
|
||||
intrinsics, depth_pixel, distance
|
||||
)
|
||||
return distance, camera_coordinate
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
camera = RealSenseCamera(is_depth_frame=False)
|
||||
camera.get_serial_num()
|
||||
camera.start_camera()
|
||||
# camera.set_frame_rate(60, 60)
|
||||
color_image, depth_image, colorized_depth, point_cloud = camera.read_frame()
|
||||
camera.stop_camera()
|
||||
logging.info(f"Color image shape: {color_image.shape}")
|
||||
# logging.info(f"Depth image shape: {depth_image.shape}")
|
||||
# logging.info(f"Colorized depth image shape: {colorized_depth.shape}")
|
||||
# logging.info(f"Point cloud shape: {point_cloud.shape}")
|
||||
logging.info(f"Color timestamp: {camera.timestamp}")
|
||||
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
|
||||
logging.info(f"Color timestamp: {camera.color_timestamp}")
|
||||
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
|
||||
logging.info("Test passed")
|
||||
@@ -0,0 +1,101 @@
|
||||
import pyrealsense2 as rs
|
||||
import numpy as np
|
||||
import h5py
|
||||
import time
|
||||
import threading
|
||||
import keyboard # 用于监听键盘输入
|
||||
|
||||
# 全局变量
|
||||
is_recording = False # 标志位,控制录制状态
|
||||
color_images = [] # 存储彩色图像
|
||||
depth_images = [] # 存储深度图像
|
||||
timestamps = [] # 存储时间戳
|
||||
|
||||
# 配置D435相机
|
||||
def configure_camera():
|
||||
pipeline = rs.pipeline()
|
||||
config = rs.config()
|
||||
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) # 彩色图像流
|
||||
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30) # 深度图像流
|
||||
pipeline.start(config)
|
||||
return pipeline
|
||||
|
||||
# 监听键盘输入,控制录制状态
|
||||
def listen_for_keyboard():
|
||||
global is_recording
|
||||
while True:
|
||||
if keyboard.is_pressed('s'): # 按下 's' 开始录制
|
||||
is_recording = True
|
||||
print("Recording started.")
|
||||
time.sleep(0.5) # 防止重复触发
|
||||
elif keyboard.is_pressed('q'): # 按下 'q' 停止录制
|
||||
is_recording = False
|
||||
print("Recording stopped.")
|
||||
time.sleep(0.5) # 防止重复触发
|
||||
elif keyboard.is_pressed('e'): # 按下 'e' 退出程序
|
||||
print("Exiting program.")
|
||||
exit()
|
||||
time.sleep(0.1)
|
||||
|
||||
# 采集图像数据
|
||||
def capture_frames(pipeline):
|
||||
global is_recording, color_images, depth_images, timestamps
|
||||
try:
|
||||
while True:
|
||||
if is_recording:
|
||||
frames = pipeline.wait_for_frames()
|
||||
color_frame = frames.get_color_frame()
|
||||
depth_frame = frames.get_depth_frame()
|
||||
|
||||
if not color_frame or not depth_frame:
|
||||
continue
|
||||
|
||||
# 获取当前时间戳
|
||||
timestamp = time.time()
|
||||
|
||||
# 将图像转换为numpy数组
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
depth_image = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
# 存储数据
|
||||
color_images.append(color_image)
|
||||
depth_images.append(depth_image)
|
||||
timestamps.append(timestamp)
|
||||
|
||||
print(f"Captured frame at {timestamp}")
|
||||
|
||||
else:
|
||||
time.sleep(0.1) # 如果未录制,等待一段时间
|
||||
|
||||
finally:
|
||||
pipeline.stop()
|
||||
|
||||
# 保存为HDF5文件
|
||||
def save_to_hdf5(color_images, depth_images, timestamps, filename="output.h5"):
|
||||
with h5py.File(filename, "w") as f:
|
||||
f.create_dataset("color_images", data=np.array(color_images), compression="gzip")
|
||||
f.create_dataset("depth_images", data=np.array(depth_images), compression="gzip")
|
||||
f.create_dataset("timestamps", data=np.array(timestamps), compression="gzip")
|
||||
print(f"Data saved to {filename}")
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
global is_recording, color_images, depth_images, timestamps
|
||||
|
||||
# 启动键盘监听线程
|
||||
keyboard_thread = threading.Thread(target=listen_for_keyboard)
|
||||
keyboard_thread.daemon = True
|
||||
keyboard_thread.start()
|
||||
|
||||
# 配置相机
|
||||
pipeline = configure_camera()
|
||||
|
||||
# 开始采集图像
|
||||
capture_frames(pipeline)
|
||||
|
||||
# 录制结束后保存数据
|
||||
if color_images and depth_images and timestamps:
|
||||
save_to_hdf5(color_images, depth_images, timestamps, "mobile_aloha_data.h5")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
152
realman_src/realman_aloha/shadow_camera/test/test_camera.py
Normal file
152
realman_src/realman_aloha/shadow_camera/test/test_camera.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
from os import path
|
||||
import pyrealsense2 as rs
|
||||
from shadow_camera import realsense
|
||||
import logging
|
||||
|
||||
|
||||
|
||||
def test_camera():
|
||||
camera = realsense.RealSenseCamera('241122071186')
|
||||
camera.start_camera()
|
||||
|
||||
while True:
|
||||
# result = camera.read_align_frame()
|
||||
# if result is None:
|
||||
# print('is None')
|
||||
# continue
|
||||
# start_time = time.time()
|
||||
color_image, depth_image, colorized_depth, vtx = camera.read_frame()
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||
|
||||
print(f"color_image: {color_image.shape}")
|
||||
# print(f"Time: {end_time - start_time}")
|
||||
cv2.imshow("bgr_image", color_image)
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
camera.stop_camera()
|
||||
|
||||
|
||||
def test_get_serial_num():
|
||||
camera = realsense.RealSenseCamera()
|
||||
device = camera.get_serial_num()
|
||||
|
||||
|
||||
class CameraCapture:
|
||||
def __init__(self, camera_serial_num=None, save_dir="./save"):
|
||||
self._camera_serial_num = camera_serial_num
|
||||
self._color_save_dir = path.join(save_dir, "color")
|
||||
self._depth_save_dir = path.join(save_dir, "depth")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
os.makedirs(self._color_save_dir, exist_ok=True)
|
||||
os.makedirs(self._depth_save_dir, exist_ok=True)
|
||||
|
||||
def get_serial_num(self):
|
||||
self._camera_serial_num = {}
|
||||
camera_names = ["left", "right", "head", "table"]
|
||||
context = rs.context()
|
||||
devices = context.query_devices() # 获取所有设备
|
||||
if len(context.devices) > 0:
|
||||
for i, device in enumerate(devices):
|
||||
self._camera_serial_num[camera_names[i]] = device.get_info(
|
||||
rs.camera_info.serial_number
|
||||
)
|
||||
print(self._camera_serial_num)
|
||||
|
||||
return self._camera_serial_num
|
||||
|
||||
def start_camera(self):
|
||||
if self._camera_serial_num is None:
|
||||
self.get_serial_num()
|
||||
self._camera_left = realsense.RealSenseCamera(self._camera_serial_num["left"])
|
||||
self._camera_right = realsense.RealSenseCamera(self._camera_serial_num["right"])
|
||||
self._camera_head = realsense.RealSenseCamera(self._camera_serial_num["head"])
|
||||
|
||||
self._camera_left.start_camera()
|
||||
self._camera_right.start_camera()
|
||||
self._camera_head.start_camera()
|
||||
|
||||
def stop_camera(self):
|
||||
self._camera_left.stop_camera()
|
||||
self._camera_right.stop_camera()
|
||||
self._camera_head.stop_camera()
|
||||
|
||||
def _save_datas(self, timestamp, color_image, depth_image, camera_name):
|
||||
color_filename = path.join(
|
||||
self._color_save_dir, f"{timestamp}" + camera_name + ".jpg"
|
||||
)
|
||||
depth_filename = path.join(
|
||||
self._depth_save_dir, f"{timestamp}" + camera_name + ".png"
|
||||
)
|
||||
cv2.imwrite(color_filename, color_image)
|
||||
cv2.imwrite(depth_filename, depth_image)
|
||||
|
||||
def capture_images(self):
|
||||
while True:
|
||||
(
|
||||
color_image_left,
|
||||
depth_image_left,
|
||||
_,
|
||||
_,
|
||||
) = self._camera_left.read_align_frame()
|
||||
(
|
||||
color_image_right,
|
||||
depth_image_right,
|
||||
_,
|
||||
_,
|
||||
) = self._camera_right.read_align_frame()
|
||||
(
|
||||
color_image_head,
|
||||
depth_image_head,
|
||||
_,
|
||||
point_cloud3,
|
||||
) = self._camera_head.read_align_frame()
|
||||
|
||||
bgr_color_image_left = cv2.cvtColor(color_image_left, cv2.COLOR_RGB2BGR)
|
||||
bgr_color_image_right = cv2.cvtColor(color_image_right, cv2.COLOR_RGB2BGR)
|
||||
bgr_color_image_head = cv2.cvtColor(color_image_head, cv2.COLOR_RGB2BGR)
|
||||
|
||||
timestamp = time.time() * 1000
|
||||
|
||||
cv2.imshow("Camera left", bgr_color_image_left)
|
||||
cv2.imshow("Camera right", bgr_color_image_right)
|
||||
cv2.imshow("Camera head", bgr_color_image_head)
|
||||
|
||||
# self._save_datas(
|
||||
# timestamp, bgr_color_image_left, depth_image_left, "left"
|
||||
# )
|
||||
# self._save_datas(
|
||||
# timestamp, bgr_color_image_right, depth_image_right, "right"
|
||||
# )
|
||||
# self._save_datas(
|
||||
# timestamp, bgr_color_image_head, depth_image_head, "head"
|
||||
# )
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
#test_camera()
|
||||
test_get_serial_num()
|
||||
"""
|
||||
输入相机序列号制定左右相机:
|
||||
dict:{'left': '241222075132', 'right': '242322076532', 'head': '242322076532'}
|
||||
保存路径:
|
||||
str:./save
|
||||
输入为空,自动分配相机序列号(不指定左、右、头部),保存路径为./save
|
||||
"""
|
||||
|
||||
# capture = CameraCapture()
|
||||
# capture.get_serial_num()
|
||||
# test_get_serial_num()
|
||||
|
||||
# capture.start_camera()
|
||||
# capture.capture_images()
|
||||
# capture.stop_camera()
|
||||
@@ -0,0 +1,71 @@
|
||||
import pytest
|
||||
import pyrealsense2 as rs
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
|
||||
|
||||
class TestRealSenseCamera:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_camera(self):
|
||||
self.camera = RealSenseCamera()
|
||||
|
||||
def test_get_serial_num(self):
|
||||
serial_nums = self.camera.get_serial_num()
|
||||
assert isinstance(serial_nums, dict)
|
||||
assert len(serial_nums) > 0
|
||||
|
||||
def test_start_stop_camera(self):
|
||||
self.camera.start_camera()
|
||||
assert self.camera.camera_on is True
|
||||
self.camera.stop_camera()
|
||||
assert self.camera.camera_on is False
|
||||
|
||||
def test_set_resolution(self):
|
||||
color_resolution = [1280, 720]
|
||||
depth_resolution = [1280, 720]
|
||||
self.camera.set_resolution(color_resolution, depth_resolution)
|
||||
assert self.camera._color_resolution == color_resolution
|
||||
assert self.camera._depth_resolution == depth_resolution
|
||||
|
||||
def test_set_frame_rate(self):
|
||||
color_fps = 60
|
||||
depth_fps = 60
|
||||
self.camera.set_frame_rate(color_fps, depth_fps)
|
||||
assert self.camera._color_frames_rate == color_fps
|
||||
assert self.camera._depth_frames_rate == depth_fps
|
||||
|
||||
def test_read_frame(self):
|
||||
self.camera.start_camera()
|
||||
color_image, depth_image, colorized_depth, point_cloud = (
|
||||
self.camera.read_frame()
|
||||
)
|
||||
assert color_image is not None
|
||||
assert depth_image is not None
|
||||
self.camera.stop_camera()
|
||||
|
||||
def test_read_align_frame(self):
|
||||
self.camera.start_camera()
|
||||
color_image, depth_image, colorized_depth, point_cloud = (
|
||||
self.camera.read_align_frame()
|
||||
)
|
||||
assert color_image is not None
|
||||
assert depth_image is not None
|
||||
self.camera.stop_camera()
|
||||
|
||||
def test_get_camera_intrinsics(self):
|
||||
self.camera.start_camera()
|
||||
color_intrinsics, depth_intrinsics = self.camera.get_camera_intrinsics()
|
||||
assert color_intrinsics is not None
|
||||
assert depth_intrinsics is not None
|
||||
self.camera.stop_camera()
|
||||
|
||||
def test_get_3d_camera_coordinate(self):
|
||||
self.camera.start_camera()
|
||||
# 先调用 read_align_frame 方法以确保 _aligned_depth_frame 被设置
|
||||
self.camera.read_align_frame()
|
||||
depth_pixel = [320, 240]
|
||||
distance, camera_coordinate = self.camera.get_3d_camera_coordinate(
|
||||
depth_pixel, align=True
|
||||
)
|
||||
assert distance > 0
|
||||
assert len(camera_coordinate) == 3
|
||||
self.camera.stop_camera()
|
||||
10
realman_src/realman_aloha/shadow_rm_act/.gitignore
vendored
Normal file
10
realman_src/realman_aloha/shadow_rm_act/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
__pycache__/
|
||||
build/
|
||||
devel/
|
||||
dist/
|
||||
data/
|
||||
.catkin_workspace
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pt
|
||||
.vscode/
|
||||
89
realman_src/realman_aloha/shadow_rm_act/README.md
Normal file
89
realman_src/realman_aloha/shadow_rm_act/README.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# ACT: Action Chunking with Transformers
|
||||
|
||||
### *New*: [ACT tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing)
|
||||
TL;DR: if your ACT policy is jerky or pauses in the middle of an episode, just train for longer! Success rate and smoothness can improve way after loss plateaus.
|
||||
|
||||
#### Project Website: https://tonyzhaozh.github.io/aloha/
|
||||
|
||||
This repo contains the implementation of ACT, together with 2 simulated environments:
|
||||
Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim or real.
|
||||
For real, you would also need to install [ALOHA](https://github.com/tonyzhaozh/aloha).
|
||||
|
||||
### Updates:
|
||||
You can find all scripted/human demo for simulated environments [here](https://drive.google.com/drive/folders/1gPR03v05S1xiInoVJn7G7VJ9pDCnxq9O?usp=share_link).
|
||||
|
||||
|
||||
### Repo Structure
|
||||
- ``imitate_episodes.py`` Train and Evaluate ACT
|
||||
- ``policy.py`` An adaptor for ACT policy
|
||||
- ``detr`` Model definitions of ACT, modified from DETR
|
||||
- ``sim_env.py`` Mujoco + DM_Control environments with joint space control
|
||||
- ``ee_sim_env.py`` Mujoco + DM_Control environments with EE space control
|
||||
- ``scripted_policy.py`` Scripted policies for sim environments
|
||||
- ``constants.py`` Constants shared across files
|
||||
- ``utils.py`` Utils such as data loading and helper functions
|
||||
- ``visualize_episodes.py`` Save videos from a .hdf5 dataset
|
||||
|
||||
|
||||
### Installation
|
||||
|
||||
conda create -n aloha python=3.8.10
|
||||
conda activate aloha
|
||||
pip install torchvision
|
||||
pip install torch
|
||||
pip install pyquaternion
|
||||
pip install pyyaml
|
||||
pip install rospkg
|
||||
pip install pexpect
|
||||
pip install mujoco==2.3.7
|
||||
pip install dm_control==1.0.14
|
||||
pip install opencv-python
|
||||
pip install matplotlib
|
||||
pip install einops
|
||||
pip install packaging
|
||||
pip install h5py
|
||||
pip install ipython
|
||||
cd act/detr && pip install -e .
|
||||
|
||||
### Example Usages
|
||||
|
||||
To set up a new terminal, run:
|
||||
|
||||
conda activate aloha
|
||||
cd <path to act repo>
|
||||
|
||||
### Simulated experiments
|
||||
|
||||
We use ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
|
||||
To generated 50 episodes of scripted data, run:
|
||||
|
||||
python3 record_sim_episodes.py \
|
||||
--task_name sim_transfer_cube_scripted \
|
||||
--dataset_dir <data save dir> \
|
||||
--num_episodes 50
|
||||
|
||||
To can add the flag ``--onscreen_render`` to see real-time rendering.
|
||||
To visualize the episode after it is collected, run
|
||||
|
||||
python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0
|
||||
|
||||
To train ACT:
|
||||
|
||||
# Transfer Cube task
|
||||
python3 imitate_episodes.py \
|
||||
--task_name sim_transfer_cube_scripted \
|
||||
--ckpt_dir <ckpt dir> \
|
||||
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
|
||||
--num_epochs 2000 --lr 1e-5 \
|
||||
--seed 0
|
||||
|
||||
|
||||
To evaluate the policy, run the same command but add ``--eval``. This loads the best validation checkpoint.
|
||||
The success rate should be around 90% for transfer cube, and around 50% for insertion.
|
||||
To enable temporal ensembling, add flag ``--temporal_agg``.
|
||||
Videos will be saved to ``<ckpt_dir>`` for each rollout.
|
||||
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.
|
||||
|
||||
For real-world data where things can be harder to model, train for at least 5000 epochs or 3-4 times the length after the loss has plateaued.
|
||||
Please refer to [tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing) for more info.
|
||||
|
||||
74
realman_src/realman_aloha/shadow_rm_act/config/config.yaml
Normal file
74
realman_src/realman_aloha/shadow_rm_act/config/config.yaml
Normal file
@@ -0,0 +1,74 @@
|
||||
robot_env: {
|
||||
# TODO change the path to the correct one
|
||||
rm_left_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_left_arm.yaml',
|
||||
rm_right_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_right_arm.yaml',
|
||||
arm_axis: 6,
|
||||
head_camera: '215222076892',
|
||||
bottom_camera: '215222076981',
|
||||
left_camera: '152122078151',
|
||||
right_camera: '152122073489',
|
||||
# init_left_arm_angle: [0.226, 21.180, 91.304, -0.515, 67.486, 2.374, 0.9],
|
||||
# init_right_arm_angle: [-1.056, 33.057, 84.376, -0.204, 66.357, -3.236, 0.9]
|
||||
init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
|
||||
init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
|
||||
}
|
||||
dataset_dir: '/home/rm/aloha/shadow_rm_aloha/data/dataset/20250103'
|
||||
checkpoint_dir: '/home/rm/aloha/shadow_rm_act/data'
|
||||
# checkpoint_name: 'policy_best.ckpt'
|
||||
checkpoint_name: 'policy_9500.ckpt'
|
||||
state_dim: 14
|
||||
save_episode: True
|
||||
num_rollouts: 50 #训练期间要收集的 rollout(轨迹)数量
|
||||
real_robot: True
|
||||
policy_class: 'ACT'
|
||||
onscreen_render: False
|
||||
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right']
|
||||
episode_len: 300 #episode 的最大长度(时间步数)。
|
||||
task_name: 'aloha_01_11.28'
|
||||
temporal_agg: False #是否使用时间聚合
|
||||
batch_size: 8 #训练期间每批的样本数。
|
||||
seed: 1000 #随机种子。
|
||||
chunk_size: 30 #用于处理序列的块大小
|
||||
eval_every: 1 #每隔 eval_every 步评估一次模型。
|
||||
num_steps: 10000 #训练的总步数。
|
||||
validate_every: 1 #每隔 validate_every 步验证一次模型。
|
||||
save_every: 500 #每隔 save_every 步保存一次检查点。
|
||||
load_pretrain: False #是否加载预训练模型。
|
||||
resume_ckpt_path:
|
||||
name_filter: # TODO
|
||||
skip_mirrored_data: False #是否跳过镜像数据(例如用于基于对称性的数据增强)。
|
||||
stats_dir:
|
||||
sample_weights:
|
||||
train_ratio: 0.8 #用于训练的数据比例(其余数据用于验证)
|
||||
|
||||
policy_config: {
|
||||
hidden_dim: 512, # Size of the embeddings (dimension of the transformer)
|
||||
state_dim: 14, # Dimension of the state
|
||||
position_embedding: 'sine', # ('sine', 'learned').Type of positional embedding to use on top of the image features
|
||||
lr_backbone: 1.0e-5,
|
||||
masks: False, # If true, the model masks the non-visible pixels
|
||||
backbone: 'resnet18',
|
||||
dilation: False, # If true, we replace stride with dilation in the last convolutional block (DC5)
|
||||
dropout: 0.1, # Dropout applied in the transformer
|
||||
nheads: 8,
|
||||
dim_feedforward: 3200, # Intermediate size of the feedforward layers in the transformer blocks
|
||||
enc_layers: 4, # Number of encoding layers in the transformer
|
||||
dec_layers: 7, # Number of decoding layers in the transformer
|
||||
pre_norm: False, # If true, apply LayerNorm to the input instead of the output of the MultiheadAttention and FeedForward
|
||||
num_queries: 30,
|
||||
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right'],
|
||||
vq: False,
|
||||
vq_class: none,
|
||||
vq_dim: 64,
|
||||
action_dim: 14,
|
||||
no_encoder: False,
|
||||
lr: 1.0e-5,
|
||||
weight_decay: 1.0e-4,
|
||||
kl_weight: 10,
|
||||
|
||||
# lr_drop: 200,
|
||||
# clip_max_norm: 0.1,
|
||||
}
|
||||
|
||||
|
||||
|
||||
267
realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py
Normal file
267
realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import numpy as np
|
||||
import collections
|
||||
import os
|
||||
|
||||
from constants import DT, XML_DIR, START_ARM_POSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_CLOSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
||||
|
||||
from src.shadow_act.utils.utils import sample_box_pose, sample_insertion_pose
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import base
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def make_ee_sim_env(task_name):
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with end-effector control.
|
||||
Action space: [left_arm_pose (7), # position and quaternion for end effector
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_pose (7), # position and quaternion for end effector
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_transfer_cube.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = TransferCubeEETask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
elif 'sim_insertion' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_insertion.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = InsertionEETask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
|
||||
class BimanualViperXEETask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
a_len = len(action) // 2
|
||||
action_left = action[:a_len]
|
||||
action_right = action[a_len:]
|
||||
|
||||
# set mocap position and quat
|
||||
# left
|
||||
np.copyto(physics.data.mocap_pos[0], action_left[:3])
|
||||
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], action_right[:3])
|
||||
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
|
||||
|
||||
# set gripper
|
||||
g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7])
|
||||
g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7])
|
||||
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
|
||||
|
||||
def initialize_robots(self, physics):
|
||||
# reset joint position
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
|
||||
# reset mocap to align with end effector
|
||||
# to obtain these numbers:
|
||||
# (1) make an ee_sim env and reset to the same start_pose
|
||||
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
|
||||
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
|
||||
# repeat the same for right side
|
||||
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
|
||||
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
|
||||
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
|
||||
|
||||
# reset gripper control
|
||||
close_gripper_control = np.array([
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
])
|
||||
np.copyto(physics.data.ctrl, close_gripper_control)
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
# note: it is important to do .copy()
|
||||
obs = collections.OrderedDict()
|
||||
obs['qpos'] = self.get_qpos(physics)
|
||||
obs['qvel'] = self.get_qvel(physics)
|
||||
obs['env_state'] = self.get_env_state(physics)
|
||||
obs['images'] = dict()
|
||||
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
|
||||
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
|
||||
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
|
||||
# used in scripted policy to obtain starting pose
|
||||
obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
|
||||
obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()
|
||||
|
||||
# used when replaying joint trajectory
|
||||
obs['gripper_ctrl'] = physics.data.ctrl.copy()
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeEETask(BimanualViperXEETask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize box position
|
||||
cube_pose = sample_box_pose()
|
||||
box_start_idx = physics.model.name2id('red_box_joint', 'joint')
|
||||
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionEETask(BimanualViperXEETask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize peg and socket position
|
||||
peg_pose, socket_pose = sample_insertion_pose()
|
||||
id2index = lambda j_id: 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
|
||||
|
||||
peg_start_id = physics.model.name2id('red_peg_joint', 'joint')
|
||||
peg_start_idx = id2index(peg_start_id)
|
||||
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
socket_start_id = physics.model.name2id('blue_socket_joint', 'joint')
|
||||
socket_start_idx = id2index(socket_start_id)
|
||||
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
|
||||
("socket-2", "table") in all_contact_pairs or \
|
||||
("socket-3", "table") in all_contact_pairs or \
|
||||
("socket-4", "table") in all_contact_pairs
|
||||
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
|
||||
("red_peg", "socket-2") in all_contact_pairs or \
|
||||
("red_peg", "socket-3") in all_contact_pairs or \
|
||||
("red_peg", "socket-4") in all_contact_pairs
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
36
realman_src/realman_aloha/shadow_rm_act/pyproject.toml
Normal file
36
realman_src/realman_aloha/shadow_rm_act/pyproject.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[tool.poetry]
|
||||
name = "shadow_act"
|
||||
version = "0.1.0"
|
||||
description = "Embodied data, ACT and other methods; training and verification function packages"
|
||||
readme = "README.md"
|
||||
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||
license = "MIT"
|
||||
# include = ["realman_vision/pytransform/_pytransform.so",]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux amd64",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9"
|
||||
wandb = ">=0.18.0"
|
||||
einops = ">=0.8.0"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
|
||||
pytest = ">=8.3"
|
||||
black = ">=24.10.0"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.8.4"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
189
realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py
Normal file
189
realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
import h5py
|
||||
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
|
||||
from ee_sim_env import make_ee_sim_env
|
||||
from sim_env import make_sim_env, BOX_POSE
|
||||
from scripted_policy import PickAndTransferPolicy, InsertionPolicy
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Generate demonstration data in simulation.
|
||||
First rollout the policy (defined in ee space) in ee_sim_env. Obtain the joint trajectory.
|
||||
Replace the gripper joint positions with the commanded joint position.
|
||||
Replay this joint trajectory (as action sequence) in sim_env, and record all observations.
|
||||
Save this episode of data, and continue to next episode of data collection.
|
||||
"""
|
||||
|
||||
task_name = args['task_name']
|
||||
dataset_dir = args['dataset_dir']
|
||||
num_episodes = args['num_episodes']
|
||||
onscreen_render = args['onscreen_render']
|
||||
inject_noise = False
|
||||
render_cam_name = 'angle'
|
||||
|
||||
if not os.path.isdir(dataset_dir):
|
||||
os.makedirs(dataset_dir, exist_ok=True)
|
||||
|
||||
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
|
||||
if task_name == 'sim_transfer_cube_scripted':
|
||||
policy_cls = PickAndTransferPolicy
|
||||
elif task_name == 'sim_insertion_scripted':
|
||||
policy_cls = InsertionPolicy
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
success = []
|
||||
for episode_idx in range(num_episodes):
|
||||
print(f'{episode_idx=}')
|
||||
print('Rollout out EE space scripted policy')
|
||||
# setup the environment
|
||||
env = make_ee_sim_env(task_name)
|
||||
ts = env.reset()
|
||||
episode = [ts]
|
||||
policy = policy_cls(inject_noise)
|
||||
# setup plotting
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
|
||||
plt.ion()
|
||||
for step in range(episode_len):
|
||||
action = policy(ts)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images'][render_cam_name])
|
||||
plt.pause(0.002)
|
||||
plt.close()
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode[1:]])
|
||||
episode_max_reward = np.max([ts.reward for ts in episode[1:]])
|
||||
if episode_max_reward == env.task.max_reward:
|
||||
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||
else:
|
||||
print(f"{episode_idx=} Failed")
|
||||
|
||||
joint_traj = [ts.observation['qpos'] for ts in episode]
|
||||
# replace gripper pose with gripper control
|
||||
gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
|
||||
for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
|
||||
left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
|
||||
right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
|
||||
joint[6] = left_ctrl
|
||||
joint[6+7] = right_ctrl
|
||||
|
||||
subtask_info = episode[0].observation['env_state'].copy() # box pose at step 0
|
||||
|
||||
# clear unused variables
|
||||
del env
|
||||
del episode
|
||||
del policy
|
||||
|
||||
# setup the environment
|
||||
print('Replaying joint commands')
|
||||
env = make_sim_env(task_name)
|
||||
BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
|
||||
ts = env.reset()
|
||||
|
||||
episode_replay = [ts]
|
||||
# setup plotting
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
|
||||
plt.ion()
|
||||
for t in range(len(joint_traj)): # note: this will increase episode length by 1
|
||||
action = joint_traj[t]
|
||||
ts = env.step(action)
|
||||
episode_replay.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images'][render_cam_name])
|
||||
plt.pause(0.02)
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
|
||||
episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
|
||||
if episode_max_reward == env.task.max_reward:
|
||||
success.append(1)
|
||||
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||
else:
|
||||
success.append(0)
|
||||
print(f"{episode_idx=} Failed")
|
||||
|
||||
plt.close()
|
||||
|
||||
"""
|
||||
For each timestep:
|
||||
observations
|
||||
- images
|
||||
- each_cam_name (480, 640, 3) 'uint8'
|
||||
- qpos (14,) 'float64'
|
||||
- qvel (14,) 'float64'
|
||||
|
||||
action (14,) 'float64'
|
||||
"""
|
||||
|
||||
data_dict = {
|
||||
'/observations/qpos': [],
|
||||
'/observations/qvel': [],
|
||||
'/action': [],
|
||||
}
|
||||
for cam_name in camera_names:
|
||||
data_dict[f'/observations/images/{cam_name}'] = []
|
||||
|
||||
# because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
|
||||
# truncate here to be consistent
|
||||
joint_traj = joint_traj[:-1]
|
||||
episode_replay = episode_replay[:-1]
|
||||
|
||||
# len(joint_traj) i.e. actions: max_timesteps
|
||||
# len(episode_replay) i.e. time steps: max_timesteps + 1
|
||||
max_timesteps = len(joint_traj)
|
||||
while joint_traj:
|
||||
action = joint_traj.pop(0)
|
||||
ts = episode_replay.pop(0)
|
||||
data_dict['/observations/qpos'].append(ts.observation['qpos'])
|
||||
data_dict['/observations/qvel'].append(ts.observation['qvel'])
|
||||
data_dict['/action'].append(action)
|
||||
for cam_name in camera_names:
|
||||
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
|
||||
|
||||
# HDF5
|
||||
t0 = time.time()
|
||||
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
|
||||
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
|
||||
root.attrs['sim'] = True
|
||||
obs = root.create_group('observations')
|
||||
image = obs.create_group('images')
|
||||
for cam_name in camera_names:
|
||||
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
|
||||
chunks=(1, 480, 640, 3), )
|
||||
# compression='gzip',compression_opts=2,)
|
||||
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
|
||||
qpos = obs.create_dataset('qpos', (max_timesteps, 14))
|
||||
qvel = obs.create_dataset('qvel', (max_timesteps, 14))
|
||||
action = root.create_dataset('action', (max_timesteps, 14))
|
||||
|
||||
for name, array in data_dict.items():
|
||||
root[name][...] = array
|
||||
print(f'Saving: {time.time() - t0:.1f} secs\n')
|
||||
|
||||
print(f'Saved to {dataset_dir}')
|
||||
print(f'Success: {np.sum(success)} / {len(success)}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
|
||||
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
|
||||
parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
|
||||
parser.add_argument('--onscreen_render', action='store_true')
|
||||
|
||||
main(vars(parser.parse_args()))
|
||||
|
||||
194
realman_src/realman_aloha/shadow_rm_act/scripted_policy.py
Normal file
194
realman_src/realman_aloha/shadow_rm_act/scripted_policy.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pyquaternion import Quaternion
|
||||
|
||||
from constants import SIM_TASK_CONFIGS
|
||||
from ee_sim_env import make_ee_sim_env
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
class BasePolicy:
|
||||
def __init__(self, inject_noise=False):
|
||||
self.inject_noise = inject_noise
|
||||
self.step_count = 0
|
||||
self.left_trajectory = None
|
||||
self.right_trajectory = None
|
||||
|
||||
def generate_trajectory(self, ts_first):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def interpolate(curr_waypoint, next_waypoint, t):
|
||||
t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
|
||||
curr_xyz = curr_waypoint['xyz']
|
||||
curr_quat = curr_waypoint['quat']
|
||||
curr_grip = curr_waypoint['gripper']
|
||||
next_xyz = next_waypoint['xyz']
|
||||
next_quat = next_waypoint['quat']
|
||||
next_grip = next_waypoint['gripper']
|
||||
xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac
|
||||
quat = curr_quat + (next_quat - curr_quat) * t_frac
|
||||
gripper = curr_grip + (next_grip - curr_grip) * t_frac
|
||||
return xyz, quat, gripper
|
||||
|
||||
def __call__(self, ts):
|
||||
# generate trajectory at first timestep, then open-loop execution
|
||||
if self.step_count == 0:
|
||||
self.generate_trajectory(ts)
|
||||
|
||||
# obtain left and right waypoints
|
||||
if self.left_trajectory[0]['t'] == self.step_count:
|
||||
self.curr_left_waypoint = self.left_trajectory.pop(0)
|
||||
next_left_waypoint = self.left_trajectory[0]
|
||||
|
||||
if self.right_trajectory[0]['t'] == self.step_count:
|
||||
self.curr_right_waypoint = self.right_trajectory.pop(0)
|
||||
next_right_waypoint = self.right_trajectory[0]
|
||||
|
||||
# interpolate between waypoints to obtain current pose and gripper command
|
||||
left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count)
|
||||
right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count)
|
||||
|
||||
# Inject noise
|
||||
if self.inject_noise:
|
||||
scale = 0.01
|
||||
left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape)
|
||||
right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape)
|
||||
|
||||
action_left = np.concatenate([left_xyz, left_quat, [left_gripper]])
|
||||
action_right = np.concatenate([right_xyz, right_quat, [right_gripper]])
|
||||
|
||||
self.step_count += 1
|
||||
return np.concatenate([action_left, action_right])
|
||||
|
||||
|
||||
class PickAndTransferPolicy(BasePolicy):
|
||||
|
||||
def generate_trajectory(self, ts_first):
|
||||
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
|
||||
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
|
||||
|
||||
box_info = np.array(ts_first.observation['env_state'])
|
||||
box_xyz = box_info[:3]
|
||||
box_quat = box_info[3:]
|
||||
# print(f"Generate trajectory for {box_xyz=}")
|
||||
|
||||
gripper_pick_quat = Quaternion(init_mocap_pose_right[3:])
|
||||
gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
|
||||
|
||||
meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)
|
||||
|
||||
meet_xyz = np.array([0, 0.5, 0.25])
|
||||
|
||||
self.left_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
|
||||
{"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # approach meet position
|
||||
{"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # move to meet position
|
||||
{"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0}, # close gripper
|
||||
{"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # move left
|
||||
{"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # stay
|
||||
]
|
||||
|
||||
self.right_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
|
||||
{"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1}, # approach the cube
|
||||
{"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1}, # go down
|
||||
{"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0}, # close gripper
|
||||
{"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0}, # approach meet position
|
||||
{"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0}, # move to meet position
|
||||
{"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1}, # open gripper
|
||||
{"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # move to right
|
||||
{"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # stay
|
||||
]
|
||||
|
||||
|
||||
class InsertionPolicy(BasePolicy):
|
||||
|
||||
def generate_trajectory(self, ts_first):
|
||||
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
|
||||
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
|
||||
|
||||
peg_info = np.array(ts_first.observation['env_state'])[:7]
|
||||
peg_xyz = peg_info[:3]
|
||||
peg_quat = peg_info[3:]
|
||||
|
||||
socket_info = np.array(ts_first.observation['env_state'])[7:]
|
||||
socket_xyz = socket_info[:3]
|
||||
socket_quat = socket_info[3:]
|
||||
|
||||
gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:])
|
||||
gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
|
||||
|
||||
gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:])
|
||||
gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60)
|
||||
|
||||
meet_xyz = np.array([0, 0.5, 0.15])
|
||||
lift_right = 0.00715
|
||||
|
||||
self.left_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
|
||||
{"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # approach the cube
|
||||
{"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # go down
|
||||
{"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # close gripper
|
||||
{"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # approach meet position
|
||||
{"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements,"gripper": 0}, # insertion
|
||||
{"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # insertion
|
||||
]
|
||||
|
||||
self.right_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
|
||||
{"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # approach the cube
|
||||
{"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # go down
|
||||
{"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # close gripper
|
||||
{"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # approach meet position
|
||||
{"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
|
||||
{"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
|
||||
|
||||
]
|
||||
|
||||
|
||||
def test_policy(task_name):
|
||||
# example rolling out pick_and_transfer policy
|
||||
onscreen_render = True
|
||||
inject_noise = False
|
||||
|
||||
# setup the environment
|
||||
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
env = make_ee_sim_env('sim_transfer_cube')
|
||||
elif 'sim_insertion' in task_name:
|
||||
env = make_ee_sim_env('sim_insertion')
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for episode_idx in range(2):
|
||||
ts = env.reset()
|
||||
episode = [ts]
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['angle'])
|
||||
plt.ion()
|
||||
|
||||
policy = PickAndTransferPolicy(inject_noise)
|
||||
for step in range(episode_len):
|
||||
action = policy(ts)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images']['angle'])
|
||||
plt.pause(0.02)
|
||||
plt.close()
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode[1:]])
|
||||
if episode_return > 0:
|
||||
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||
else:
|
||||
print(f"{episode_idx=} Failed")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_task_name = 'sim_transfer_cube_scripted'
|
||||
test_policy(test_task_name)
|
||||
|
||||
278
realman_src/realman_aloha/shadow_rm_act/sim_env.py
Normal file
278
realman_src/realman_aloha/shadow_rm_act/sim_env.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import collections
|
||||
import matplotlib.pyplot as plt
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import base
|
||||
|
||||
from constants import DT, XML_DIR, START_ARM_POSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
|
||||
from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
BOX_POSE = [None] # to be changed from outside
|
||||
|
||||
def make_sim_env(task_name):
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_transfer_cube.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = TransferCubeTask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
elif 'sim_insertion' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_insertion.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = InsertionTask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
|
||||
class BimanualViperXTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
left_arm_action = action[:6]
|
||||
right_arm_action = action[7:7+6]
|
||||
normalized_left_gripper_action = action[6]
|
||||
normalized_right_gripper_action = action[7+6]
|
||||
|
||||
left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action)
|
||||
right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action)
|
||||
|
||||
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||
|
||||
env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action])
|
||||
super().before_step(env_action, physics)
|
||||
return
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
obs = collections.OrderedDict()
|
||||
obs['qpos'] = self.get_qpos(physics)
|
||||
obs['qvel'] = self.get_qvel(physics)
|
||||
obs['env_state'] = self.get_env_state(physics)
|
||||
obs['images'] = dict()
|
||||
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
|
||||
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
|
||||
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
|
||||
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7*2:] = BOX_POSE[0] # two objects
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
|
||||
("socket-2", "table") in all_contact_pairs or \
|
||||
("socket-3", "table") in all_contact_pairs or \
|
||||
("socket-4", "table") in all_contact_pairs
|
||||
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
|
||||
("red_peg", "socket-2") in all_contact_pairs or \
|
||||
("red_peg", "socket-3") in all_contact_pairs or \
|
||||
("red_peg", "socket-4") in all_contact_pairs
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
def get_action(master_bot_left, master_bot_right):
|
||||
action = np.zeros(14)
|
||||
# arm action
|
||||
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
||||
action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
|
||||
# gripper action
|
||||
left_gripper_pos = master_bot_left.dxl.joint_states.position[7]
|
||||
right_gripper_pos = master_bot_right.dxl.joint_states.position[7]
|
||||
normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos)
|
||||
normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos)
|
||||
action[6] = normalized_left_pos
|
||||
action[7+6] = normalized_right_pos
|
||||
return action
|
||||
|
||||
def test_sim_teleop():
|
||||
""" Testing teleoperation in sim with ALOHA. Requires hardware and ALOHA repo to work. """
|
||||
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||
|
||||
BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]
|
||||
|
||||
# source of data
|
||||
master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
||||
robot_name=f'master_left', init_node=True)
|
||||
master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
||||
robot_name=f'master_right', init_node=False)
|
||||
|
||||
# setup the environment
|
||||
env = make_sim_env('sim_transfer_cube')
|
||||
ts = env.reset()
|
||||
episode = [ts]
|
||||
# setup plotting
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['angle'])
|
||||
plt.ion()
|
||||
|
||||
for t in range(1000):
|
||||
action = get_action(master_bot_left, master_bot_right)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
|
||||
plt_img.set_data(ts.observation['images']['angle'])
|
||||
plt.pause(0.02)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sim_teleop()
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,575 @@
|
||||
import os
|
||||
import time
|
||||
import yaml
|
||||
import torch
|
||||
import pickle
|
||||
import dm_env
|
||||
import logging
|
||||
import collections
|
||||
import numpy as np
|
||||
import tracemalloc
|
||||
from einops import rearrange
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision import transforms
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
from shadow_act.models.latent_model import Latent_Model_Transformer
|
||||
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||
from shadow_act.utils.utils import set_seed
|
||||
|
||||
|
||||
# 配置logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
# # 隐藏h5py的警告Creating converter from 7 to 5
|
||||
# logging.getLogger("h5py").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RmActEvaluator:
|
||||
def __init__(self, config, save_episode=True, num_rollouts=50):
|
||||
"""
|
||||
初始化Evaluator类
|
||||
|
||||
Args:
|
||||
config (dict): 配置字典
|
||||
checkpoint_name (str): 检查点名称
|
||||
save_episode (bool): 是否保存每个episode
|
||||
num_rollouts (int): 滚动次数
|
||||
"""
|
||||
self.config = config
|
||||
self._seed = config["seed"]
|
||||
self.robot_env = config["robot_env"]
|
||||
self.checkpoint_dir = config["checkpoint_dir"]
|
||||
self.checkpoint_name = config["checkpoint_name"]
|
||||
self.save_episode = save_episode
|
||||
self.num_rollouts = num_rollouts
|
||||
self.state_dim = config["state_dim"]
|
||||
self.real_robot = config["real_robot"]
|
||||
self.policy_class = config["policy_class"]
|
||||
self.onscreen_render = config["onscreen_render"]
|
||||
self.camera_names = config["camera_names"]
|
||||
self.max_timesteps = config["episode_len"]
|
||||
self.task_name = config["task_name"]
|
||||
self.temporal_agg = config["temporal_agg"]
|
||||
self.onscreen_cam = "angle"
|
||||
self.policy_config = config["policy_config"]
|
||||
self.vq = config["policy_config"]["vq"]
|
||||
# self.actuator_config = config["actuator_config"]
|
||||
# self.use_actuator_net = self.actuator_config["actuator_network_dir"] is not None
|
||||
self.stats = None
|
||||
self.env = None
|
||||
self.env_max_reward = 0
|
||||
|
||||
def _make_policy(self, policy_class, policy_config):
|
||||
"""
|
||||
根据策略类和配置创建策略对象
|
||||
|
||||
Args:
|
||||
policy_class (str): 策略类名称
|
||||
policy_config (dict): 策略配置字典
|
||||
|
||||
Returns:
|
||||
policy: 创建的策略对象
|
||||
"""
|
||||
if policy_class == "ACT":
|
||||
return ACTPolicy(policy_config)
|
||||
elif policy_class == "CNNMLP":
|
||||
return CNNMLPPolicy(policy_config)
|
||||
elif policy_class == "Diffusion":
|
||||
return DiffusionPolicy(policy_config)
|
||||
else:
|
||||
raise NotImplementedError(f"Policy class {policy_class} is not implemented")
|
||||
|
||||
def load_policy_and_stats(self):
|
||||
"""
|
||||
加载策略和统计数据
|
||||
"""
|
||||
checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
|
||||
logging.info(f"Loading policy from: {checkpoint_path}")
|
||||
self.policy = self._make_policy(self.policy_class, self.policy_config)
|
||||
# 加载模型并设置为评估模式
|
||||
self.policy.load_state_dict(torch.load(checkpoint_path, weights_only=True))
|
||||
self.policy.cuda()
|
||||
self.policy.eval()
|
||||
|
||||
if self.vq:
|
||||
vq_dim = self.config["policy_config"]["vq_dim"]
|
||||
vq_class = self.config["policy_config"]["vq_class"]
|
||||
self.latent_model = Latent_Model_Transformer(vq_dim, vq_dim, vq_class)
|
||||
latent_model_checkpoint_path = os.path.join(
|
||||
self.checkpoint_dir, "latent_model_last.ckpt"
|
||||
)
|
||||
self.latent_model.deserialize(torch.load(latent_model_checkpoint_path))
|
||||
self.latent_model.eval()
|
||||
self.latent_model.cuda()
|
||||
logging.info(
|
||||
f"Loaded policy from: {checkpoint_path}, latent model from: {latent_model_checkpoint_path}"
|
||||
)
|
||||
else:
|
||||
logging.info(f"Loaded: {checkpoint_path}")
|
||||
|
||||
stats_path = os.path.join(self.checkpoint_dir, "dataset_stats.pkl")
|
||||
with open(stats_path, "rb") as f:
|
||||
self.stats = pickle.load(f)
|
||||
|
||||
def pre_process(self, state_qpos):
|
||||
"""
|
||||
预处理状态位置
|
||||
|
||||
Args:
|
||||
state_qpos (np.array): 状态位置数组
|
||||
|
||||
Returns:
|
||||
np.array: 预处理后的状态位置
|
||||
"""
|
||||
if self.policy_class == "Diffusion":
|
||||
return ((state_qpos + 1) / 2) * (
|
||||
self.stats["action_max"] - self.stats["action_min"]
|
||||
) + self.stats["action_min"]
|
||||
# 标准化处理,均值为 0,标准差为 1
|
||||
|
||||
return (state_qpos - self.stats["qpos_mean"]) / self.stats["qpos_std"]
|
||||
|
||||
def post_process(self, action):
|
||||
"""
|
||||
后处理动作
|
||||
|
||||
Args:
|
||||
action (np.array): 动作数组
|
||||
|
||||
Returns:
|
||||
np.array: 后处理后的动作
|
||||
"""
|
||||
# 反标准化处理
|
||||
return action * self.stats["action_std"] + self.stats["action_mean"]
|
||||
|
||||
def get_image_torch(self, timestep, camera_names, random_crop_resize=False):
|
||||
"""
|
||||
获取图像
|
||||
|
||||
Args:
|
||||
timestep (object): 时间步对象
|
||||
camera_names (list): 相机名称列表
|
||||
random_crop_resize (bool): 是否随机裁剪和调整大小
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 处理后的图像,归一化(num_cameras, channels, height, width)
|
||||
"""
|
||||
current_images = []
|
||||
for cam_name in camera_names:
|
||||
current_image = rearrange(
|
||||
timestep.observation["images"][cam_name], "h w c -> c h w"
|
||||
)
|
||||
current_images.append(current_image)
|
||||
current_image = np.stack(current_images, axis=0)
|
||||
current_image = (
|
||||
torch.from_numpy(current_image / 255.0).float().cuda().unsqueeze(0)
|
||||
)
|
||||
|
||||
if random_crop_resize:
|
||||
logging.info("Random crop resize is used!")
|
||||
original_size = current_image.shape[-2:]
|
||||
ratio = 0.95
|
||||
current_image = current_image[
|
||||
...,
|
||||
int(original_size[0] * (1 - ratio) / 2) : int(
|
||||
original_size[0] * (1 + ratio) / 2
|
||||
),
|
||||
int(original_size[1] * (1 - ratio) / 2) : int(
|
||||
original_size[1] * (1 + ratio) / 2
|
||||
),
|
||||
]
|
||||
current_image = current_image.squeeze(0)
|
||||
resize_transform = transforms.Resize(original_size, antialias=True)
|
||||
current_image = resize_transform(current_image)
|
||||
current_image = current_image.unsqueeze(0)
|
||||
|
||||
return current_image
|
||||
|
||||
def load_environment(self):
|
||||
"""
|
||||
加载环境
|
||||
"""
|
||||
if self.real_robot:
|
||||
self.env = DeviceAloha(self.robot_env)
|
||||
self.env_max_reward = 0
|
||||
else:
|
||||
from sim_env import make_sim_env
|
||||
|
||||
self.env = make_sim_env(self.task_name)
|
||||
self.env_max_reward = self.env.task.max_reward
|
||||
|
||||
def get_auto_index(self, checkpoint_dir):
|
||||
max_idx = 1000
|
||||
for i in range(max_idx + 1):
|
||||
if not os.path.isfile(os.path.join(checkpoint_dir, f"qpos_{i}.npy")):
|
||||
return i
|
||||
raise Exception(f"Error getting auto index, or more than {max_idx} episodes")
|
||||
|
||||
def evaluate(self, checkpoint_name=None):
|
||||
"""
|
||||
评估策略
|
||||
|
||||
Returns:
|
||||
tuple: 成功率和平均回报
|
||||
"""
|
||||
if checkpoint_name is not None:
|
||||
self.checkpoint_name = checkpoint_name
|
||||
set_seed(self._seed) # np与torch的随机种子
|
||||
self.load_policy_and_stats()
|
||||
self.load_environment()
|
||||
|
||||
query_frequency = self.policy_config["num_queries"]
|
||||
|
||||
# 时间聚合时,每个时间步只有1个查询
|
||||
if self.temporal_agg:
|
||||
query_frequency = 1
|
||||
num_queries = self.policy_config["num_queries"]
|
||||
|
||||
# # 真实机器人时,基础延迟为13???
|
||||
# if self.real_robot:
|
||||
# BASE_DELAY = 13
|
||||
# # query_frequency -= BASE_DELAY
|
||||
|
||||
max_timesteps = int(self.max_timesteps * 1) # may increase for real-world tasks
|
||||
episode_returns = []
|
||||
highest_rewards = []
|
||||
|
||||
for rollout_id in range(self.num_rollouts):
|
||||
|
||||
timestep = self.env.reset()
|
||||
|
||||
if self.onscreen_render:
|
||||
# TODO 画图
|
||||
pass
|
||||
if self.temporal_agg:
|
||||
all_time_actions = torch.zeros(
|
||||
[max_timesteps, max_timesteps + num_queries, self.state_dim]
|
||||
).cuda()
|
||||
qpos_history_raw = np.zeros((max_timesteps, self.state_dim))
|
||||
rewards = []
|
||||
|
||||
with torch.inference_mode():
|
||||
time_0 = time.time()
|
||||
DT = 1 / 30
|
||||
culmulated_delay = 0
|
||||
for t in range(max_timesteps):
|
||||
time_1 = time.time()
|
||||
if self.onscreen_render:
|
||||
# TODO 显示图像
|
||||
pass
|
||||
# process previous timestep to get qpos and image_list
|
||||
obs = timestep.observation
|
||||
qpos_numpy = np.array(obs["qpos"])
|
||||
qpos_history_raw[t] = qpos_numpy
|
||||
qpos = self.pre_process(qpos_numpy)
|
||||
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
|
||||
|
||||
logging.info(f"t{t}")
|
||||
|
||||
if t % query_frequency == 0:
|
||||
current_image = self.get_image_torch(
|
||||
timestep,
|
||||
self.camera_names,
|
||||
random_crop_resize=(
|
||||
self.config["policy_class"] == "Diffusion"
|
||||
),
|
||||
)
|
||||
|
||||
if t == 0:
|
||||
# 网络预热
|
||||
for _ in range(10):
|
||||
self.policy(qpos, current_image)
|
||||
logging.info("Network warm up done")
|
||||
|
||||
if self.config["policy_class"] == "ACT":
|
||||
if t % query_frequency == 0:
|
||||
if self.vq:
|
||||
if rollout_id == 0:
|
||||
for _ in range(10):
|
||||
vq_sample = self.latent_model.generate(
|
||||
1, temperature=1, x=None
|
||||
)
|
||||
logging.info(
|
||||
torch.nonzero(vq_sample[0])[:, 1]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
vq_sample = self.latent_model.generate(
|
||||
1, temperature=1, x=None
|
||||
)
|
||||
all_actions = self.policy(
|
||||
qpos, current_image, vq_sample=vq_sample
|
||||
)
|
||||
else:
|
||||
all_actions = self.policy(qpos, current_image)
|
||||
# if self.real_robot:
|
||||
# all_actions = torch.cat(
|
||||
# [
|
||||
# all_actions[:, :-BASE_DELAY, :-2],
|
||||
# all_actions[:, BASE_DELAY:, -2:],
|
||||
# ],
|
||||
# dim=2,
|
||||
# )
|
||||
if self.temporal_agg:
|
||||
all_time_actions[[t], t : t + num_queries] = all_actions
|
||||
actions_for_curr_step = all_time_actions[:, t]
|
||||
actions_populated = torch.all(
|
||||
actions_for_curr_step != 0, axis=1
|
||||
)
|
||||
actions_for_curr_step = actions_for_curr_step[
|
||||
actions_populated
|
||||
]
|
||||
k = 0.01
|
||||
exp_weights = np.exp(
|
||||
-k * np.arange(len(actions_for_curr_step))
|
||||
)
|
||||
exp_weights = exp_weights / exp_weights.sum()
|
||||
exp_weights = (
|
||||
torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
)
|
||||
raw_action = (actions_for_curr_step * exp_weights).sum(
|
||||
dim=0, keepdim=True
|
||||
)
|
||||
else:
|
||||
raw_action = all_actions[:, t % query_frequency]
|
||||
elif self.config["policy_class"] == "Diffusion":
|
||||
if t % query_frequency == 0:
|
||||
all_actions = self.policy(qpos, current_image)
|
||||
# if self.real_robot:
|
||||
# all_actions = torch.cat(
|
||||
# [
|
||||
# all_actions[:, :-BASE_DELAY, :-2],
|
||||
# all_actions[:, BASE_DELAY:, -2:],
|
||||
# ],
|
||||
# dim=2,
|
||||
# )
|
||||
raw_action = all_actions[:, t % query_frequency]
|
||||
elif self.config["policy_class"] == "CNNMLP":
|
||||
raw_action = self.policy(qpos, current_image)
|
||||
all_actions = raw_action.unsqueeze(0)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
### post-process actions
|
||||
raw_action = raw_action.squeeze(0).cpu().numpy()
|
||||
action = self.post_process(raw_action)
|
||||
|
||||
### step the environment
|
||||
if self.real_robot:
|
||||
logging.info(f" action = {action}")
|
||||
timestep = self.env.step(action)
|
||||
|
||||
rewards.append(timestep.reward)
|
||||
duration = time.time() - time_1
|
||||
sleep_time = max(0, DT - duration)
|
||||
time.sleep(sleep_time)
|
||||
if duration >= DT:
|
||||
culmulated_delay += duration - DT
|
||||
logging.warning(
|
||||
f"Warning: step duration: {duration:.3f} s at step {t} longer than DT: {DT} s, culmulated delay: {culmulated_delay:.3f} s"
|
||||
)
|
||||
|
||||
logging.info(f"Avg fps {max_timesteps / (time.time() - time_0)}")
|
||||
plt.close()
|
||||
|
||||
if self.real_robot:
|
||||
log_id = self.get_auto_index(self.checkpoint_dir)
|
||||
np.save(
|
||||
os.path.join(self.checkpoint_dir, f"qpos_{log_id}.npy"),
|
||||
qpos_history_raw,
|
||||
)
|
||||
plt.figure(figsize=(10, 20))
|
||||
for i in range(self.state_dim):
|
||||
plt.subplot(self.state_dim, 1, i + 1)
|
||||
plt.plot(qpos_history_raw[:, i])
|
||||
if i != self.state_dim - 1:
|
||||
plt.xticks([])
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(self.checkpoint_dir, f"qpos_{log_id}.png"))
|
||||
plt.close()
|
||||
|
||||
rewards = np.array(rewards)
|
||||
episode_return = np.sum(rewards[rewards != None])
|
||||
episode_returns.append(episode_return)
|
||||
episode_highest_reward = np.max(rewards)
|
||||
highest_rewards.append(episode_highest_reward)
|
||||
logging.info(
|
||||
f"Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {self.env_max_reward=}, Success: {episode_highest_reward == self.env_max_reward}"
|
||||
)
|
||||
|
||||
success_rate = np.mean(np.array(highest_rewards) == self.env_max_reward)
|
||||
avg_return = np.mean(episode_returns)
|
||||
summary_str = (
|
||||
f"\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n"
|
||||
)
|
||||
for r in range(self.env_max_reward + 1):
|
||||
more_or_equal_r = (np.array(highest_rewards) >= r).sum()
|
||||
more_or_equal_r_rate = more_or_equal_r / self.num_rollouts
|
||||
summary_str += f"Reward >= {r}: {more_or_equal_r}/{self.num_rollouts} = {more_or_equal_r_rate * 100}%\n"
|
||||
|
||||
logging.info(summary_str)
|
||||
|
||||
result_file_name = "result_" + self.checkpoint_name.split(".")[0] + ".txt"
|
||||
with open(os.path.join(self.checkpoint_dir, result_file_name), "w") as f:
|
||||
f.write(summary_str)
|
||||
f.write(repr(episode_returns))
|
||||
f.write("\n\n")
|
||||
f.write(repr(highest_rewards))
|
||||
|
||||
return success_rate, avg_return
|
||||
|
||||
|
||||
class DeviceAloha:
|
||||
def __init__(self, aloha_config):
|
||||
"""
|
||||
初始化设备
|
||||
|
||||
Args:
|
||||
device_name (str): 设备名称
|
||||
"""
|
||||
config_left_arm = aloha_config["rm_left_arm"]
|
||||
config_right_arm = aloha_config["rm_right_arm"]
|
||||
config_head_camera = aloha_config["head_camera"]
|
||||
config_bottom_camera = aloha_config["bottom_camera"]
|
||||
config_left_camera = aloha_config["left_camera"]
|
||||
config_right_camera = aloha_config["right_camera"]
|
||||
self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
|
||||
self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
|
||||
self.arm_axis = aloha_config["arm_axis"]
|
||||
self.arm_left = RmArm(config_left_arm)
|
||||
self.arm_right = RmArm(config_right_arm)
|
||||
self.camera_left = RealSenseCamera(config_head_camera, False)
|
||||
self.camera_right = RealSenseCamera(config_bottom_camera, False)
|
||||
self.camera_bottom = RealSenseCamera(config_left_camera, False)
|
||||
self.camera_top = RealSenseCamera(config_right_camera, False)
|
||||
self.camera_left.start_camera()
|
||||
self.camera_right.start_camera()
|
||||
self.camera_bottom.start_camera()
|
||||
self.camera_top.start_camera()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
关闭摄像头
|
||||
"""
|
||||
self.camera_left.close()
|
||||
self.camera_right.close()
|
||||
self.camera_bottom.close()
|
||||
self.camera_top.close()
|
||||
|
||||
def get_qps(self):
|
||||
"""
|
||||
获取关节角度
|
||||
|
||||
Returns:
|
||||
np.array: 关节角度
|
||||
"""
|
||||
left_slave_arm_angle = self.arm_left.get_joint_angle()
|
||||
left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
|
||||
right_slave_arm_angle = self.arm_right.get_joint_angle()
|
||||
right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
|
||||
return np.concatenate([left_joint_angles_array, right_joint_angles_array])
|
||||
|
||||
def get_qvel(self):
|
||||
"""
|
||||
获取关节速度
|
||||
|
||||
Returns:
|
||||
np.array: 关节速度
|
||||
"""
|
||||
left_slave_arm_velocity = self.arm_left.get_joint_velocity()
|
||||
left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
|
||||
right_slave_arm_velocity = self.arm_right.get_joint_velocity()
|
||||
right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
|
||||
return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
|
||||
|
||||
def get_effort(self):
|
||||
"""
|
||||
获取关节力
|
||||
|
||||
Returns:
|
||||
np.array: 关节力
|
||||
"""
|
||||
left_slave_arm_effort = self.arm_left.get_joint_effort()
|
||||
left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
|
||||
right_slave_arm_effort = self.arm_right.get_joint_effort()
|
||||
right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
|
||||
return np.concatenate([left_joint_effort_array, right_joint_effort_array])
|
||||
|
||||
def get_images(self):
|
||||
"""
|
||||
获取图像
|
||||
|
||||
Returns:
|
||||
dict: 图像字典
|
||||
"""
|
||||
self.top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
|
||||
self.bottom_image, _, _, _ = self.camera_bottom.read_frame(
|
||||
True, False, False, False
|
||||
)
|
||||
self.left_image, _, _, _ = self.camera_left.read_frame(
|
||||
True, False, False, False
|
||||
)
|
||||
self.right_image, _, _, _ = self.camera_right.read_frame(
|
||||
True, False, False, False
|
||||
)
|
||||
return {
|
||||
"cam_high": self.top_image,
|
||||
"cam_low": self.bottom_image,
|
||||
"cam_left": self.left_image,
|
||||
"cam_right": self.right_image,
|
||||
}
|
||||
|
||||
def get_observation(self):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qps()
|
||||
obs["qvel"] = self.get_qvel()
|
||||
obs["effort"] = self.get_effort()
|
||||
obs["images"] = self.get_images()
|
||||
return obs
|
||||
|
||||
def reset(self):
|
||||
logging.info("Resetting the environment")
|
||||
self.arm_left.set_joint_position(self.init_left_arm_angle[0:self.arm_axis])
|
||||
self.arm_right.set_joint_position(self.init_right_arm_angle[0:self.arm_axis])
|
||||
self.arm_left.set_gripper_position(0)
|
||||
self.arm_right.set_gripper_position(0)
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.FIRST,
|
||||
reward=0,
|
||||
discount=None,
|
||||
observation=self.get_observation(),
|
||||
)
|
||||
|
||||
def step(self, target_angle):
|
||||
self.arm_left.set_joint_canfd_position(target_angle[0:self.arm_axis])
|
||||
self.arm_right.set_joint_canfd_position(target_angle[self.arm_axis+1:self.arm_axis*2+1])
|
||||
self.arm_left.set_gripper_position(target_angle[self.arm_axis])
|
||||
self.arm_right.set_gripper_position(target_angle[(self.arm_axis*2 + 1)])
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID,
|
||||
reward=0,
|
||||
discount=None,
|
||||
observation=self.get_observation(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
|
||||
# config = yaml.safe_load(f)
|
||||
# aloha_config = config["robot_env"]
|
||||
# device = DeviceAloha(aloha_config)
|
||||
# device.reset()
|
||||
# while True:
|
||||
# init_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
|
||||
# time_step = time.time()
|
||||
# timestep = device.step(init_angle)
|
||||
# logging.info(f"Time: {time.time() - time_step}")
|
||||
# obs = timestep.observation
|
||||
|
||||
with open("/home/wang/project/shadow_rm_act/config/config.yaml", "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
# logging.info(f"Config: {config}")
|
||||
evaluator = RmActEvaluator(config)
|
||||
success_rate, avg_return = evaluator.evaluate()
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Backbone modules.
|
||||
"""
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from typing import Dict, List
|
||||
import torch.nn.functional as F
|
||||
from .position_encoding import build_position_encoding
|
||||
from torchvision.models import ResNet18_Weights
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from shadow_act.utils.misc import NestedTensor, is_main_process
|
||||
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super(FrozenBatchNorm2d, self).__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbone: nn.Module,
|
||||
train_backbone: bool,
|
||||
num_channels: int,
|
||||
return_interm_layers: bool,
|
||||
):
|
||||
super().__init__()
|
||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||
# parameter.requires_grad_(False)
|
||||
if return_interm_layers:
|
||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
else:
|
||||
return_layers = {"layer4": "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor):
|
||||
xs = self.body(tensor)
|
||||
return xs
|
||||
# out: Dict[str, NestedTensor] = {}
|
||||
# for name, x in xs.items():
|
||||
# m = tensor_list.mask
|
||||
# assert m is not None
|
||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
# out[name] = NestedTensor(x, mask)
|
||||
# return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
train_backbone: bool,
|
||||
return_interm_layers: bool,
|
||||
dilation: bool,
|
||||
):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
weights=ResNet18_Weights.IMAGENET1K_V1 if is_main_process() else None,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# backbone = getattr(torchvision.models, name)(
|
||||
# replace_stride_with_dilation=[False, False, dilation],
|
||||
# pretrained=is_main_process(),
|
||||
# norm_layer=FrozenBatchNorm2d,
|
||||
# ) # pretrained # TODO do we want frozen batch_norm??
|
||||
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for name, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(
|
||||
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
|
||||
):
|
||||
|
||||
position_embedding = build_position_encoding(
|
||||
hidden_dim=hidden_dim, position_embedding_type=position_embedding_type
|
||||
)
|
||||
train_backbone = lr_backbone > 0
|
||||
return_interm_layers = masks
|
||||
backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation)
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = backbone.num_channels
|
||||
return model
|
||||
@@ -0,0 +1,436 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DETR model and criterion classes.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
import torch.nn.functional as F
|
||||
from shadow_act.models.transformer import Transformer
|
||||
from .backbone import build_backbone
|
||||
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
def reparametrize(mu, logvar):
|
||||
std = logvar.div(2).exp()
|
||||
eps = Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std * eps
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
def get_position_angle_vec(position):
|
||||
return [
|
||||
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
||||
for hid_j in range(d_hid)
|
||||
]
|
||||
|
||||
sinusoid_table = np.array(
|
||||
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
|
||||
)
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
"""This is the DETR module that performs object detection"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim,
|
||||
num_queries,
|
||||
camera_names,
|
||||
vq,
|
||||
vq_class,
|
||||
vq_dim,
|
||||
action_dim,
|
||||
):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
|
||||
self.state_dim, self.action_dim = state_dim, action_dim
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(
|
||||
backbones[0].num_channels, hidden_dim, kernel_size=1
|
||||
)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
else:
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
self.backbones = None
|
||||
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_action_proj = nn.Linear(
|
||||
action_dim, hidden_dim
|
||||
) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(
|
||||
action_dim, hidden_dim
|
||||
) # project qpos to embedding
|
||||
if self.vq:
|
||||
self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
|
||||
else:
|
||||
self.latent_proj = nn.Linear(
|
||||
hidden_dim, self.latent_dim * 2
|
||||
) # project hidden state to latent std, var
|
||||
self.register_buffer(
|
||||
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
|
||||
) # [CLS], qpos, a_seq
|
||||
|
||||
# decoder extra parameters
|
||||
if self.vq:
|
||||
self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
|
||||
else:
|
||||
self.latent_out_proj = nn.Linear(
|
||||
self.latent_dim, hidden_dim
|
||||
) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(
|
||||
2, hidden_dim
|
||||
) # learned position embedding for proprio and latent
|
||||
|
||||
def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
|
||||
bs, _ = qpos.shape
|
||||
if self.encoder is None:
|
||||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
|
||||
qpos.device
|
||||
)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
probs = binaries = mu = logvar = None
|
||||
else:
|
||||
# cvae encoder
|
||||
is_training = actions is not None # train or val
|
||||
### Obtain latent z from action sequence
|
||||
if is_training:
|
||||
# project action sequence to embedding dim, and concat with a CLS token
|
||||
action_embed = self.encoder_action_proj(
|
||||
actions
|
||||
) # (bs, seq, hidden_dim)
|
||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
||||
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
|
||||
bs, 1, 1
|
||||
) # (bs, 1, hidden_dim)
|
||||
encoder_input = torch.cat(
|
||||
[cls_embed, qpos_embed, action_embed], axis=1
|
||||
) # (bs, seq+1, hidden_dim)
|
||||
encoder_input = encoder_input.permute(
|
||||
1, 0, 2
|
||||
) # (seq+1, bs, hidden_dim)
|
||||
# do not mask cls token
|
||||
cls_joint_is_pad = torch.full((bs, 2), False).to(
|
||||
qpos.device
|
||||
) # False: not a padding
|
||||
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# obtain position embedding
|
||||
pos_embed = self.pos_table.clone().detach()
|
||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||
# query model
|
||||
encoder_output = self.encoder(
|
||||
encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
|
||||
)
|
||||
encoder_output = encoder_output[0] # take cls output only
|
||||
latent_info = self.latent_proj(encoder_output)
|
||||
|
||||
if self.vq:
|
||||
logits = latent_info.reshape(
|
||||
[*latent_info.shape[:-1], self.vq_class, self.vq_dim]
|
||||
)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
binaries = (
|
||||
F.one_hot(
|
||||
torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(
|
||||
-1
|
||||
),
|
||||
self.vq_dim,
|
||||
)
|
||||
.view(-1, self.vq_class, self.vq_dim)
|
||||
.float()
|
||||
)
|
||||
binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
|
||||
probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
|
||||
straigt_through = binaries_flat - probs_flat.detach() + probs_flat
|
||||
latent_input = self.latent_out_proj(straigt_through)
|
||||
mu = logvar = None
|
||||
else:
|
||||
probs = binaries = None
|
||||
mu = latent_info[:, : self.latent_dim]
|
||||
logvar = latent_info[:, self.latent_dim :]
|
||||
latent_sample = reparametrize(mu, logvar)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
else:
|
||||
mu = logvar = binaries = probs = None
|
||||
if self.vq:
|
||||
latent_input = self.latent_out_proj(
|
||||
vq_sample.view(-1, self.vq_class * self.vq_dim)
|
||||
)
|
||||
else:
|
||||
latent_sample = torch.zeros(
|
||||
[bs, self.latent_dim], dtype=torch.float32
|
||||
).to(qpos.device)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
return latent_input, probs, binaries, mu, logvar
|
||||
|
||||
def forward(
|
||||
self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None
|
||||
):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
|
||||
latent_input, probs, binaries, mu, logvar = self.encode(
|
||||
qpos, actions, is_pad, vq_sample
|
||||
)
|
||||
|
||||
# cvae decoder
|
||||
if self.backbones is not None:
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
all_cam_pos = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
# TODO: fix this error
|
||||
features, pos = self.backbones[0](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0]
|
||||
all_cam_features.append(self.input_proj(features))
|
||||
all_cam_pos.append(pos)
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(
|
||||
src,
|
||||
None,
|
||||
self.query_embed.weight,
|
||||
pos,
|
||||
latent_input,
|
||||
proprio_input,
|
||||
self.additional_pos_embed.weight,
|
||||
)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||
hs = self.transformer(
|
||||
transformer_input, None, self.query_embed.weight, self.pos.weight
|
||||
)[0]
|
||||
a_hat = self.action_head(hs)
|
||||
is_pad_hat = self.is_pad_head(hs)
|
||||
return a_hat, is_pad_hat, [mu, logvar], probs, binaries
|
||||
|
||||
|
||||
class CNNMLP(nn.Module):
|
||||
def __init__(self, backbones, state_dim, camera_names):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.camera_names = camera_names
|
||||
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
||||
if backbones is not None:
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
backbone_down_projs = []
|
||||
for backbone in backbones:
|
||||
down_proj = nn.Sequential(
|
||||
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
|
||||
nn.Conv2d(128, 64, kernel_size=5),
|
||||
nn.Conv2d(64, 32, kernel_size=5),
|
||||
)
|
||||
backbone_down_projs.append(down_proj)
|
||||
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
||||
|
||||
mlp_in_dim = 768 * len(backbones) + state_dim
|
||||
self.mlp = mlp(
|
||||
input_dim=mlp_in_dim,
|
||||
hidden_dim=1024,
|
||||
output_dim=self.action_dim,
|
||||
hidden_depth=2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0] # not used
|
||||
all_cam_features.append(self.backbone_down_projs[cam_id](features))
|
||||
# flatten everything
|
||||
flattened_features = []
|
||||
for cam_feature in all_cam_features:
|
||||
flattened_features.append(cam_feature.reshape([bs, -1]))
|
||||
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
||||
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
||||
a_hat = self.mlp(features)
|
||||
return a_hat
|
||||
|
||||
|
||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||
if hidden_depth == 0:
|
||||
mods = [nn.Linear(input_dim, output_dim)]
|
||||
else:
|
||||
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
for i in range(hidden_depth - 1):
|
||||
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||
trunk = nn.Sequential(*mods)
|
||||
return trunk
|
||||
|
||||
|
||||
def build_encoder(
|
||||
hidden_dim, # 256
|
||||
dropout, # 0.1
|
||||
nheads, # 8
|
||||
dim_feedforward,
|
||||
num_encoder_layers, # 4 # TODO shared with VAE decoder
|
||||
normalize_before, # False
|
||||
):
|
||||
activation = "relu"
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
hidden_dim, nheads, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
|
||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build_vae(
|
||||
hidden_dim,
|
||||
state_dim,
|
||||
position_embedding_type,
|
||||
lr_backbone,
|
||||
masks,
|
||||
backbone,
|
||||
dilation,
|
||||
dropout,
|
||||
nheads,
|
||||
dim_feedforward,
|
||||
enc_layers,
|
||||
dec_layers,
|
||||
pre_norm,
|
||||
num_queries,
|
||||
camera_names,
|
||||
vq,
|
||||
vq_class,
|
||||
vq_dim,
|
||||
action_dim,
|
||||
no_encoder,
|
||||
):
|
||||
# TODO hardcode
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
backbone = build_backbone(
|
||||
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
|
||||
)
|
||||
backbones.append(backbone)
|
||||
|
||||
transformer = build_transformer(
|
||||
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
|
||||
)
|
||||
|
||||
if no_encoder:
|
||||
encoder = None
|
||||
else:
|
||||
encoder = build_encoder(
|
||||
hidden_dim,
|
||||
dropout,
|
||||
nheads,
|
||||
dim_feedforward,
|
||||
enc_layers,
|
||||
pre_norm,
|
||||
)
|
||||
|
||||
model = DETRVAE(
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim,
|
||||
num_queries,
|
||||
camera_names,
|
||||
vq,
|
||||
vq_class,
|
||||
vq_dim,
|
||||
action_dim,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
|
||||
|
||||
return model
|
||||
|
||||
# TODO
|
||||
def build_cnnmlp(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
for _ in args.camera_names:
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
model = CNNMLP(
|
||||
backbones,
|
||||
state_dim=state_dim,
|
||||
camera_names=args.camera_names,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
|
||||
|
||||
return model
|
||||
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python3
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torch
|
||||
|
||||
DROPOUT_RATE = 0.1 # 定义 dropout 率
|
||||
|
||||
# 定义一个因果变压器块
|
||||
class Causal_Transformer_Block(nn.Module):
|
||||
def __init__(self, seq_len, latent_dim, num_head) -> None:
|
||||
"""
|
||||
初始化因果变压器块
|
||||
|
||||
Args:
|
||||
seq_len (int): 序列长度
|
||||
latent_dim (int): 潜在维度
|
||||
num_head (int): 注意力头的数量
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_head = num_head
|
||||
self.latent_dim = latent_dim
|
||||
self.ln_1 = nn.LayerNorm(latent_dim) # 层归一化
|
||||
self.attn = nn.MultiheadAttention(latent_dim, num_head, dropout=DROPOUT_RATE, batch_first=True) # 多头注意力机制
|
||||
self.ln_2 = nn.LayerNorm(latent_dim) # 层归一化
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(latent_dim, 4 * latent_dim), # 全连接层
|
||||
nn.GELU(), # GELU 激活函数
|
||||
nn.Linear(4 * latent_dim, latent_dim), # 全连接层
|
||||
nn.Dropout(DROPOUT_RATE), # Dropout
|
||||
)
|
||||
|
||||
# self.register_buffer("attn_mask", torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()) # 注册注意力掩码
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
前向传播
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量
|
||||
"""
|
||||
# 创建上三角掩码,防止信息泄露
|
||||
attn_mask = torch.triu(torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool), diagonal=1)
|
||||
x = self.ln_1(x) # 层归一化
|
||||
x = x + self.attn(x, x, x, attn_mask=attn_mask)[0] # 加上注意力输出
|
||||
x = self.ln_2(x) # 层归一化
|
||||
x = x + self.mlp(x) # 加上 MLP 输出
|
||||
|
||||
return x
|
||||
|
||||
# 使用自注意力机制而不是 RNN 来建模潜在空间序列
|
||||
class Latent_Model_Transformer(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, seq_len, latent_dim=256, num_head=8, num_layer=3) -> None:
|
||||
"""
|
||||
初始化潜在模型变压器
|
||||
|
||||
Args:
|
||||
input_dim (int): 输入维度
|
||||
output_dim (int): 输出维度
|
||||
seq_len (int): 序列长度
|
||||
latent_dim (int, optional): 潜在维度,默认值为 256
|
||||
num_head (int, optional): 注意力头的数量,默认值为 8
|
||||
num_layer (int, optional): 变压器层的数量,默认值为 3
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.seq_len = seq_len
|
||||
self.latent_dim = latent_dim
|
||||
self.num_head = num_head
|
||||
self.num_layer = num_layer
|
||||
self.input_layer = nn.Linear(input_dim, latent_dim) # 输入层
|
||||
self.weight_pos_embed = nn.Embedding(seq_len, latent_dim) # 位置嵌入
|
||||
self.attention_blocks = nn.Sequential(
|
||||
nn.Dropout(DROPOUT_RATE), # Dropout
|
||||
*[Causal_Transformer_Block(seq_len, latent_dim, num_head) for _ in range(num_layer)], # 多个因果变压器块
|
||||
nn.LayerNorm(latent_dim) # 层归一化
|
||||
)
|
||||
self.output_layer = nn.Linear(latent_dim, output_dim) # 输出层
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
前向传播
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量
|
||||
"""
|
||||
x = self.input_layer(x) # 输入层
|
||||
x = x + self.weight_pos_embed(torch.arange(x.shape[1], device=x.device)) # 加上位置嵌入
|
||||
x = self.attention_blocks(x) # 通过注意力块
|
||||
logits = self.output_layer(x) # 输出层
|
||||
|
||||
return logits
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, n, temperature=0.1, x=None):
|
||||
"""
|
||||
生成序列
|
||||
|
||||
Args:
|
||||
n (int): 生成序列的数量
|
||||
temperature (float, optional): 采样温度,默认值为 0.1
|
||||
x (torch.Tensor, optional): 初始输入张量,默认值为 None
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 生成的序列
|
||||
"""
|
||||
if x is None:
|
||||
x = torch.zeros((n, 1, self.input_dim), device=self.weight_pos_embed.weight.device) # 初始化输入
|
||||
for i in range(self.seq_len):
|
||||
logits = self.forward(x)[:, -1] # 获取最后一个时间步的输出
|
||||
probs = torch.softmax(logits / temperature, dim=-1) # 计算概率分布
|
||||
samples = torch.multinomial(probs, num_samples=1)[..., 0] # 从概率分布中采样
|
||||
samples_one_hot = F.one_hot(samples.long(), num_classes=self.output_dim).float() # 转为 one-hot 编码
|
||||
x = torch.cat([x, samples_one_hot[:, None, :]], dim=1) # 将新采样的结果添加到输入中
|
||||
|
||||
return x[:, 1:, :] # 返回生成的序列(去掉初始的零输入)
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from shadow_act.utils.misc import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor):
|
||||
x = tensor
|
||||
# mask = tensor_list.mask
|
||||
# assert mask is not None
|
||||
# not_mask = ~mask
|
||||
|
||||
not_mask = torch.ones_like(x[0, [0]])
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = torch.cat([
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(hidden_dim, position_embedding_type):
|
||||
N_steps = hidden_dim // 2
|
||||
if position_embedding_type in ('v2', 'sine'):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
||||
elif position_embedding_type in ('v3', 'learned'):
|
||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {position_embedding_type}")
|
||||
|
||||
return position_embedding
|
||||
@@ -0,0 +1,424 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
"""
|
||||
import copy
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
num_encoder_layers=6,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
return_intermediate_dec=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(
|
||||
encoder_layer, num_encoder_layers, encoder_norm
|
||||
)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(
|
||||
decoder_layer,
|
||||
num_decoder_layers,
|
||||
decoder_norm,
|
||||
return_intermediate=return_intermediate_dec,
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask,
|
||||
query_embed,
|
||||
pos_embed,
|
||||
latent_input=None,
|
||||
proprio_input=None,
|
||||
additional_pos_embed=None,
|
||||
):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
|
||||
1, bs, 1
|
||||
) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(
|
||||
tgt,
|
||||
memory,
|
||||
memory_key_padding_mask=mask,
|
||||
pos=pos_embed,
|
||||
query_pos=query_embed,
|
||||
)
|
||||
hs = hs.transpose(1, 2)
|
||||
return hs
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
pos=pos,
|
||||
)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos,
|
||||
query_pos=query_pos,
|
||||
)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(
|
||||
q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
||||
)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(
|
||||
q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
||||
)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(
|
||||
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
||||
)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(
|
||||
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
||||
)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
def build_transformer(
|
||||
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
|
||||
):
|
||||
return Transformer(
|
||||
d_model=hidden_dim,
|
||||
dropout=dropout,
|
||||
nhead=nheads,
|
||||
dim_feedforward=dim_feedforward,
|
||||
num_encoder_layers=enc_layers,
|
||||
num_decoder_layers=dec_layers,
|
||||
normalize_before=pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,522 @@
|
||||
import torch
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from shadow_act.models.detr_vae import build_vae, build_cnnmlp
|
||||
|
||||
# from diffusers.training_utils import EMAModel
|
||||
# from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
|
||||
# from robomimic.algo.diffusion_policy import replace_bn_with_gn, ConditionalUnet1D
|
||||
|
||||
# from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
# from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: 重构DiffusionPolicy类
|
||||
class DiffusionPolicy(nn.Module):
|
||||
def __init__(self, args_override):
|
||||
"""
|
||||
初始化DiffusionPolicy类
|
||||
|
||||
Args:
|
||||
args_override (dict): 参数覆盖字典
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.camera_names = args_override["camera_names"]
|
||||
self.observation_horizon = args_override["observation_horizon"]
|
||||
self.action_horizon = args_override["action_horizon"]
|
||||
self.prediction_horizon = args_override["prediction_horizon"]
|
||||
self.num_inference_timesteps = args_override["num_inference_timesteps"]
|
||||
self.ema_power = args_override["ema_power"]
|
||||
self.lr = args_override["lr"]
|
||||
self.weight_decay = 0
|
||||
|
||||
self.num_kp = 32
|
||||
self.feature_dimension = 64
|
||||
self.ac_dim = args_override["action_dim"]
|
||||
self.obs_dim = self.feature_dimension * len(self.camera_names) + 14
|
||||
|
||||
backbones = []
|
||||
pools = []
|
||||
linears = []
|
||||
for _ in self.camera_names:
|
||||
backbones.append(
|
||||
ResNet18Conv(input_channel=3, pretrained=False, input_coord_conv=False)
|
||||
)
|
||||
pools.append(
|
||||
SpatialSoftmax(
|
||||
input_shape=[512, 15, 20],
|
||||
num_kp=self.num_kp,
|
||||
temperature=1.0,
|
||||
learnable_temperature=False,
|
||||
noise_std=0.0,
|
||||
)
|
||||
)
|
||||
linears.append(
|
||||
torch.nn.Linear(int(np.prod([self.num_kp, 2])), self.feature_dimension)
|
||||
)
|
||||
backbones = nn.ModuleList(backbones)
|
||||
pools = nn.ModuleList(pools)
|
||||
linears = nn.ModuleList(linears)
|
||||
|
||||
backbones = replace_bn_with_gn(backbones)
|
||||
|
||||
noise_pred_net = ConditionalUnet1D(
|
||||
input_dim=self.ac_dim,
|
||||
global_cond_dim=self.obs_dim * self.observation_horizon,
|
||||
)
|
||||
|
||||
nets = nn.ModuleDict(
|
||||
{
|
||||
"policy": nn.ModuleDict(
|
||||
{
|
||||
"backbones": backbones,
|
||||
"pools": pools,
|
||||
"linears": linears,
|
||||
"noise_pred_net": noise_pred_net,
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
nets = nets.float().cuda()
|
||||
ENABLE_EMA = True
|
||||
if ENABLE_EMA:
|
||||
ema = EMAModel(model=nets, power=self.ema_power)
|
||||
else:
|
||||
ema = None
|
||||
self.nets = nets
|
||||
self.ema = ema
|
||||
|
||||
# 设置噪声调度器
|
||||
self.noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=50,
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
clip_sample=True,
|
||||
set_alpha_to_one=True,
|
||||
steps_offset=0,
|
||||
prediction_type="epsilon",
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in self.parameters())
|
||||
logger.info("number of parameters: %.2fM", n_parameters / 1e6)
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
配置优化器
|
||||
|
||||
Returns:
|
||||
optimizer: 配置的优化器
|
||||
"""
|
||||
optimizer = torch.optim.AdamW(
|
||||
self.nets.parameters(), lr=self.lr, weight_decay=self.weight_decay
|
||||
)
|
||||
return optimizer
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
"""
|
||||
前向传播函数
|
||||
|
||||
Args:
|
||||
qpos (torch.Tensor): 位置张量
|
||||
image (torch.Tensor): 图像张量
|
||||
actions (torch.Tensor, optional): 动作张量
|
||||
is_pad (torch.Tensor, optional): 填充张量
|
||||
|
||||
Returns:
|
||||
dict: 损失字典(训练时)
|
||||
torch.Tensor: 动作张量(推理时)
|
||||
"""
|
||||
B = qpos.shape[0]
|
||||
if actions is not None: # 训练时
|
||||
nets = self.nets
|
||||
all_features = []
|
||||
for cam_id in range(len(self.camera_names)):
|
||||
cam_image = image[:, cam_id]
|
||||
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
|
||||
pool_features = nets["policy"]["pools"][cam_id](cam_features)
|
||||
pool_features = torch.flatten(pool_features, start_dim=1)
|
||||
out_features = nets["policy"]["linears"][cam_id](pool_features)
|
||||
all_features.append(out_features)
|
||||
|
||||
obs_cond = torch.cat(all_features + [qpos], dim=1)
|
||||
|
||||
# 为动作添加噪声
|
||||
noise = torch.randn(actions.shape, device=obs_cond.device)
|
||||
|
||||
# 为每个数据点采样一个扩散迭代
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
self.noise_scheduler.config.num_train_timesteps,
|
||||
(B,),
|
||||
device=obs_cond.device,
|
||||
).long()
|
||||
|
||||
# 根据每个扩散迭代的噪声幅度向干净动作添加噪声
|
||||
noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps)
|
||||
|
||||
# 预测噪声残差
|
||||
noise_pred = nets["policy"]["noise_pred_net"](
|
||||
noisy_actions, timesteps, global_cond=obs_cond
|
||||
)
|
||||
|
||||
# L2损失
|
||||
all_l2 = F.mse_loss(noise_pred, noise, reduction="none")
|
||||
loss = (all_l2 * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l2_loss"] = loss
|
||||
loss_dict["loss"] = loss
|
||||
|
||||
if self.training and self.ema is not None:
|
||||
self.ema.step(nets)
|
||||
return loss_dict
|
||||
else: # 推理时
|
||||
To = self.observation_horizon
|
||||
Ta = self.action_horizon
|
||||
Tp = self.prediction_horizon
|
||||
action_dim = self.ac_dim
|
||||
|
||||
nets = self.nets
|
||||
if self.ema is not None:
|
||||
nets = self.ema.averaged_model
|
||||
|
||||
all_features = []
|
||||
for cam_id in range(len(self.camera_names)):
|
||||
cam_image = image[:, cam_id]
|
||||
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
|
||||
pool_features = nets["policy"]["pools"][cam_id](cam_features)
|
||||
pool_features = torch.flatten(pool_features, start_dim=1)
|
||||
out_features = nets["policy"]["linears"][cam_id](pool_features)
|
||||
all_features.append(out_features)
|
||||
|
||||
obs_cond = torch.cat(all_features + [qpos], dim=1)
|
||||
|
||||
# 从高斯噪声初始化动作
|
||||
noisy_action = torch.randn((B, Tp, action_dim), device=obs_cond.device)
|
||||
naction = noisy_action
|
||||
|
||||
# 初始化调度器
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
|
||||
|
||||
for k in self.noise_scheduler.timesteps:
|
||||
# 预测噪声
|
||||
noise_pred = nets["policy"]["noise_pred_net"](
|
||||
sample=naction, timestep=k, global_cond=obs_cond
|
||||
)
|
||||
|
||||
# 逆扩散步骤(去除噪声)
|
||||
naction = self.noise_scheduler.step(
|
||||
model_output=noise_pred, timestep=k, sample=naction
|
||||
).prev_sample
|
||||
|
||||
return naction
|
||||
|
||||
def serialize(self):
|
||||
"""
|
||||
序列化模型
|
||||
|
||||
Returns:
|
||||
dict: 模型状态字典
|
||||
"""
|
||||
return {
|
||||
"nets": self.nets.state_dict(),
|
||||
"ema": (
|
||||
self.ema.averaged_model.state_dict() if self.ema is not None else None
|
||||
),
|
||||
}
|
||||
|
||||
def deserialize(self, model_dict):
|
||||
"""
|
||||
反序列化模型
|
||||
|
||||
Args:
|
||||
model_dict (dict): 模型状态字典
|
||||
|
||||
Returns:
|
||||
status: 加载状态
|
||||
"""
|
||||
status = self.nets.load_state_dict(model_dict["nets"])
|
||||
logger.info("Loaded model")
|
||||
if model_dict.get("ema", None) is not None:
|
||||
logger.info("Loaded EMA")
|
||||
status_ema = self.ema.averaged_model.load_state_dict(model_dict["ema"])
|
||||
status = [status, status_ema]
|
||||
return status
|
||||
|
||||
|
||||
class ACTPolicy(nn.Module):
|
||||
def __init__(self, act_config):
|
||||
"""
|
||||
初始化ACTPolicy类
|
||||
|
||||
Args:
|
||||
args_override (dict): 参数覆盖字典
|
||||
"""
|
||||
super().__init__()
|
||||
lr_backbone = act_config["lr_backbone"]
|
||||
vq = act_config["vq"]
|
||||
lr = act_config["lr"]
|
||||
weight_decay = act_config["weight_decay"]
|
||||
|
||||
model = build_vae(
|
||||
act_config["hidden_dim"],
|
||||
act_config["state_dim"],
|
||||
act_config["position_embedding"],
|
||||
lr_backbone,
|
||||
act_config["masks"],
|
||||
act_config["backbone"],
|
||||
act_config["dilation"],
|
||||
act_config["dropout"],
|
||||
act_config["nheads"],
|
||||
act_config["dim_feedforward"],
|
||||
act_config["enc_layers"],
|
||||
act_config["dec_layers"],
|
||||
act_config["pre_norm"],
|
||||
act_config["num_queries"],
|
||||
act_config["camera_names"],
|
||||
vq,
|
||||
act_config["vq_class"],
|
||||
act_config["vq_dim"],
|
||||
act_config["action_dim"],
|
||||
act_config["no_encoder"],
|
||||
)
|
||||
model.cuda()
|
||||
|
||||
param_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "backbone" not in n and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "backbone" in n and p.requires_grad
|
||||
],
|
||||
"lr": lr_backbone,
|
||||
},
|
||||
]
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
param_dicts, lr=lr, weight_decay=weight_decay
|
||||
)
|
||||
self.model = model # CVAE解码器
|
||||
self.kl_weight = act_config["kl_weight"]
|
||||
self.vq = vq
|
||||
logger.info(f"KL Weight {self.kl_weight}")
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None, vq_sample=None):
|
||||
"""
|
||||
前向传播函数
|
||||
|
||||
Args:
|
||||
qpos (torch.Tensor): 角度张量
|
||||
image (torch.Tensor): 图像张量
|
||||
actions (torch.Tensor, optional): 动作张量
|
||||
is_pad (torch.Tensor, optional): 填充张量
|
||||
vq_sample (torch.Tensor, optional): VQ样本
|
||||
|
||||
Returns:
|
||||
dict: 损失字典(训练时)
|
||||
torch.Tensor: 动作张量(推理时)
|
||||
"""
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
image = normalize(image)
|
||||
if actions is not None: # 训练时
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
loss_dict = dict()
|
||||
a_hat, is_pad_hat, (mu, logvar), probs, binaries = self.model(
|
||||
qpos, image, env_state, actions, is_pad
|
||||
)
|
||||
if self.vq or self.model.encoder is None:
|
||||
total_kld = [torch.tensor(0.0)]
|
||||
else:
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
if self.vq:
|
||||
loss_dict["vq_discrepancy"] = F.l1_loss(
|
||||
probs, binaries, reduction="mean"
|
||||
)
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
loss_dict["l1"] = l1
|
||||
loss_dict["kl"] = total_kld[0]
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||
return loss_dict
|
||||
else: # 推理时
|
||||
a_hat, _, (_, _), _, _ = self.model(
|
||||
qpos, image, env_state, vq_sample=vq_sample
|
||||
) # no action, sample from prior
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
配置优化器
|
||||
|
||||
Returns:
|
||||
optimizer: 配置的优化器
|
||||
"""
|
||||
return self.optimizer
|
||||
|
||||
@torch.no_grad()
|
||||
def vq_encode(self, qpos, actions, is_pad):
|
||||
"""
|
||||
VQ编码
|
||||
|
||||
Args:
|
||||
qpos (torch.Tensor): 位置张量
|
||||
actions (torch.Tensor): 动作张量
|
||||
is_pad (torch.Tensor): 填充张量
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 二进制编码
|
||||
"""
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
_, _, binaries, _, _ = self.model.encode(qpos, actions, is_pad)
|
||||
|
||||
return binaries
|
||||
|
||||
def serialize(self):
|
||||
"""
|
||||
序列化模型
|
||||
|
||||
Returns:
|
||||
dict: 模型状态字典
|
||||
"""
|
||||
return self.state_dict()
|
||||
|
||||
def deserialize(self, model_dict):
|
||||
"""
|
||||
反序列化模型
|
||||
|
||||
Args:
|
||||
model_dict (dict): 模型状态字典
|
||||
|
||||
Returns:
|
||||
status: 加载状态
|
||||
"""
|
||||
return self.load_state_dict(model_dict)
|
||||
|
||||
|
||||
class CNNMLPPolicy(nn.Module):
|
||||
def __init__(self, args_override):
|
||||
"""
|
||||
初始化CNNMLPPolicy类
|
||||
|
||||
Args:
|
||||
args_override (dict): 参数覆盖字典
|
||||
"""
|
||||
super().__init__()
|
||||
# parser = argparse.ArgumentParser(
|
||||
# "DETR training and evaluation script", parents=[get_args_parser()]
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
|
||||
# for k, v in args_override.items():
|
||||
# setattr(args, k, v)
|
||||
|
||||
model = build_cnnmlp(args_override)
|
||||
model.cuda()
|
||||
|
||||
param_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "backbone" not in n and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "backbone" in n and p.requires_grad
|
||||
],
|
||||
"lr": args_override.lr_backbone,
|
||||
},
|
||||
]
|
||||
self.model = model # 解码器
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
param_dicts, lr=args_override.lr, weight_decay=args_override.weight_decay
|
||||
)
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
"""
|
||||
前向传播函数
|
||||
|
||||
Args:
|
||||
qpos (torch.Tensor): 位置张量
|
||||
image (torch.Tensor): 图像张量
|
||||
actions (torch.Tensor, optional): 动作张量
|
||||
is_pad (torch.Tensor, optional): 填充张量
|
||||
|
||||
Returns:
|
||||
dict: 损失字典(训练时)
|
||||
torch.Tensor: 动作张量(推理时)
|
||||
"""
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
image = normalize(image)
|
||||
if actions is not None: # 训练时
|
||||
actions = actions[:, 0]
|
||||
a_hat = self.model(qpos, image, env_state, actions)
|
||||
mse = F.mse_loss(actions, a_hat)
|
||||
loss_dict = dict()
|
||||
loss_dict["mse"] = mse
|
||||
loss_dict["loss"] = loss_dict["mse"]
|
||||
return loss_dict
|
||||
else: # 推理时
|
||||
a_hat = self.model(qpos, image, env_state) # 无动作,从先验中采样
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
配置优化器
|
||||
|
||||
Returns:
|
||||
optimizer: 配置的优化器
|
||||
"""
|
||||
return self.optimizer
|
||||
|
||||
|
||||
def kl_divergence(mu, logvar):
|
||||
"""
|
||||
计算KL散度
|
||||
|
||||
Args:
|
||||
mu (torch.Tensor): 均值张量
|
||||
logvar (torch.Tensor): 对数方差张量
|
||||
|
||||
Returns:
|
||||
tuple: 总KL散度,维度KL散度,均值KL散度
|
||||
"""
|
||||
batch_size = mu.size(0)
|
||||
assert batch_size != 0
|
||||
if mu.data.ndimension() == 4:
|
||||
mu = mu.view(mu.size(0), mu.size(1))
|
||||
if logvar.data.ndimension() == 4:
|
||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||
|
||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||
total_kld = klds.sum(1).mean(0, True)
|
||||
dimension_wise_kld = klds.mean(0)
|
||||
mean_kld = klds.mean(1).mean(0, True)
|
||||
|
||||
return total_kld, dimension_wise_kld, mean_kld
|
||||
@@ -0,0 +1,245 @@
|
||||
import os
|
||||
import yaml
|
||||
import pickle
|
||||
import torch
|
||||
# import wandb
|
||||
import logging
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from copy import deepcopy
|
||||
from itertools import repeat
|
||||
from shadow_act.utils.utils import (
|
||||
set_seed,
|
||||
load_data,
|
||||
compute_dict_mean,
|
||||
)
|
||||
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||
from shadow_act.eval.rm_act_eval import RmActEvaluator
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
class RmActTrainer:
|
||||
def __init__(self, config):
|
||||
"""
|
||||
初始化训练器,设置随机种子,加载数据,保存数据统计信息。
|
||||
"""
|
||||
self._config = config
|
||||
self._num_steps = config["num_steps"]
|
||||
self._ckpt_dir = config["checkpoint_dir"]
|
||||
self._state_dim = config["state_dim"]
|
||||
self._real_robot = config["real_robot"]
|
||||
self._policy_class = config["policy_class"]
|
||||
self._onscreen_render = config["onscreen_render"]
|
||||
self._policy_config = config["policy_config"]
|
||||
self._camera_names = config["camera_names"]
|
||||
self._max_timesteps = config["episode_len"]
|
||||
self._task_name = config["task_name"]
|
||||
self._temporal_agg = config["temporal_agg"]
|
||||
self._onscreen_cam = "angle"
|
||||
self._vq = config["policy_config"]["vq"]
|
||||
self._batch_size = config["batch_size"]
|
||||
|
||||
self._seed = config["seed"]
|
||||
self._eval_every = config["eval_every"]
|
||||
self._validate_every = config["validate_every"]
|
||||
self._save_every = config["save_every"]
|
||||
self._load_pretrain = config["load_pretrain"]
|
||||
self._resume_ckpt_path = config["resume_ckpt_path"]
|
||||
|
||||
if config["name_filter"] is None:
|
||||
name_filter = lambda n : True
|
||||
else:
|
||||
name_filter = config["name_filter"]
|
||||
|
||||
self._eval = RmActEvaluator(config, True, 50)
|
||||
# 加载数据
|
||||
self._train_dataloader, self._val_dataloader, self._stats, _ = load_data(
|
||||
config["dataset_dir"],
|
||||
name_filter,
|
||||
self._camera_names,
|
||||
self._batch_size,
|
||||
self._batch_size,
|
||||
config["chunk_size"],
|
||||
config["skip_mirrored_data"],
|
||||
self._load_pretrain,
|
||||
self._policy_class,
|
||||
config["stats_dir"],
|
||||
config["sample_weights"],
|
||||
config["train_ratio"],
|
||||
)
|
||||
# 保存数据统计信息
|
||||
stats_path = os.path.join(self._ckpt_dir, "dataset_stats.pkl")
|
||||
with open(stats_path, "wb") as f:
|
||||
pickle.dump(self._stats, f)
|
||||
expr_name = self._ckpt_dir.split("/")[-1]
|
||||
|
||||
# wandb.init(
|
||||
# project="train_rm_aloha", reinit=True, entity="train_rm_aloha", name=expr_name
|
||||
# )
|
||||
|
||||
|
||||
def _make_policy(self):
|
||||
"""
|
||||
根据策略类和配置创建策略对象
|
||||
"""
|
||||
if self._policy_class == "ACT":
|
||||
return ACTPolicy(self._policy_config)
|
||||
elif self._policy_class == "CNNMLP":
|
||||
return CNNMLPPolicy(self._policy_config)
|
||||
elif self._policy_class == "Diffusion":
|
||||
return DiffusionPolicy(self._policy_config)
|
||||
else:
|
||||
raise NotImplementedError(f"Policy class {self._policy_class} is not implemented")
|
||||
|
||||
def _make_optimizer(self):
|
||||
"""
|
||||
根据策略类创建优化器
|
||||
"""
|
||||
if self._policy_class in ["ACT", "CNNMLP", "Diffusion"]:
|
||||
return self._policy.configure_optimizers()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _forward_pass(self, data):
|
||||
"""
|
||||
前向传播,计算损失
|
||||
"""
|
||||
image_data, qpos_data, action_data, is_pad = data
|
||||
try:
|
||||
image_data, qpos_data, action_data, is_pad = (
|
||||
image_data.cuda(),
|
||||
qpos_data.cuda(),
|
||||
action_data.cuda(),
|
||||
is_pad.cuda(),
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logging.error(f"CUDA error: {e}")
|
||||
raise
|
||||
return self._policy(qpos_data, image_data, action_data, is_pad)
|
||||
|
||||
def _repeater(self):
|
||||
"""
|
||||
数据加载器的重复器,生成数据
|
||||
"""
|
||||
epoch = 0
|
||||
for loader in repeat(self._train_dataloader):
|
||||
for data in loader:
|
||||
yield data
|
||||
logging.info(f"Epoch {epoch} done")
|
||||
epoch += 1
|
||||
|
||||
def train(self):
|
||||
"""
|
||||
训练模型,保存最佳模型
|
||||
"""
|
||||
set_seed(self._seed)
|
||||
self._policy = self._make_policy()
|
||||
min_val_loss = np.inf
|
||||
best_ckpt_info = None
|
||||
|
||||
# 加载预训练模型
|
||||
if self._load_pretrain:
|
||||
try:
|
||||
loading_status = self._policy.deserialize(
|
||||
torch.load(
|
||||
os.path.join(
|
||||
"/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all",
|
||||
"policy_step_50000_seed_0.ckpt",
|
||||
)
|
||||
)
|
||||
)
|
||||
logging.info(f"loaded! {loading_status}")
|
||||
except FileNotFoundError as e:
|
||||
logging.error(f"Pretrain model not found: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading pretrain model: {e}")
|
||||
|
||||
# 恢复检查点
|
||||
if self._resume_ckpt_path is not None:
|
||||
try:
|
||||
loading_status = self._policy.deserialize(torch.load(self._resume_ckpt_path))
|
||||
logging.info(f"Resume policy from: {self._resume_ckpt_path}, Status: {loading_status}")
|
||||
except FileNotFoundError as e:
|
||||
logging.error(f"Checkpoint not found: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading checkpoint: {e}")
|
||||
|
||||
self._policy.cuda()
|
||||
|
||||
self._optimizer = self._make_optimizer()
|
||||
train_dataloader = self._repeater() # 重复器
|
||||
|
||||
for step in tqdm(range(self._num_steps + 1)):
|
||||
# 验证模型
|
||||
if step % self._validate_every != 0:
|
||||
continue
|
||||
logging.info("validating")
|
||||
with torch.inference_mode():
|
||||
self._policy.eval()
|
||||
validation_dicts = []
|
||||
for batch_idx, data in enumerate(self._val_dataloader):
|
||||
forward_dict = self._forward_pass(data) # forward_dict = {"loss": loss, "kl": kl, "mse": mse}
|
||||
validation_dicts.append(forward_dict)
|
||||
if batch_idx > 50: # 限制验证批次数 TODO 确定批次关联
|
||||
break
|
||||
|
||||
validation_summary = compute_dict_mean(validation_dicts)
|
||||
epoch_val_loss = validation_summary["loss"]
|
||||
if epoch_val_loss < min_val_loss:
|
||||
min_val_loss = epoch_val_loss
|
||||
best_ckpt_info = (
|
||||
step,
|
||||
min_val_loss,
|
||||
deepcopy(self._policy.serialize()),
|
||||
)
|
||||
|
||||
# wandb记录验证结果
|
||||
# for k in list(validation_summary.keys()):
|
||||
# validation_summary[f"val_{k}"] = validation_summary.pop(k)
|
||||
|
||||
# wandb.log(validation_summary, step=step)
|
||||
logging.info(f"Val loss: {epoch_val_loss:.5f}")
|
||||
summary_string = " ".join(f"{k}: {v.item():.3f}" for k, v in validation_summary.items())
|
||||
logging.info(summary_string)
|
||||
|
||||
# 评估模型
|
||||
# if (step > 0) and (step % self._eval_every == 0):
|
||||
# ckpt_name = f"policy_step_{step}_seed_{self._seed}.ckpt"
|
||||
# ckpt_path = os.path.join(self._ckpt_dir, ckpt_name)
|
||||
# torch.save(self._policy.serialize(), ckpt_path)
|
||||
# success, _ = self._eval.evaluate(ckpt_name)
|
||||
# wandb.log({"success": success}, step=step)
|
||||
|
||||
# 训练模型
|
||||
self._policy.train()
|
||||
self._optimizer.zero_grad()
|
||||
data = next(train_dataloader)
|
||||
forward_dict = self._forward_pass(data)
|
||||
loss = forward_dict["loss"]
|
||||
loss.backward()
|
||||
self._optimizer.step()
|
||||
# wandb.log(forward_dict, step=step)
|
||||
|
||||
# 保存检查点
|
||||
if step % self._save_every == 0:
|
||||
ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{step}_seed_{self._seed}.ckpt")
|
||||
torch.save(self._policy.serialize(), ckpt_path)
|
||||
|
||||
# 保存最后的模型
|
||||
ckpt_path = os.path.join(self._ckpt_dir, "policy_last.ckpt")
|
||||
torch.save(self._policy.serialize(), ckpt_path)
|
||||
|
||||
best_step, min_val_loss, best_state_dict = best_ckpt_info
|
||||
ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{best_step}_seed_{self._seed}.ckpt")
|
||||
torch.save(best_state_dict, ckpt_path)
|
||||
logging.info(f"Training finished:\nSeed {self._seed}, val loss {min_val_loss:.6f} at step {best_step}")
|
||||
|
||||
return best_ckpt_info
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("/home/rm/aloha/shadow_rm_act/config/config.yaml") as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
trainer = RmActTrainer(config)
|
||||
trainer.train()
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Utilities for bounding box manipulation and GIoU.
|
||||
"""
|
||||
import torch
|
||||
from torchvision.ops.boxes import box_area
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
||||
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_cxcywh(x):
|
||||
x0, y0, x1, y1 = x.unbind(-1)
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
||||
(x1 - x0), (y1 - y0)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/
|
||||
|
||||
The boxes should be in [x0, y0, x1, y1] format
|
||||
|
||||
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
||||
and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
area = wh[:, :, 0] * wh[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
def masks_to_boxes(masks):
|
||||
"""Compute the bounding boxes around the provided masks
|
||||
|
||||
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
||||
|
||||
Returns a [N, 4] tensors, with the boxes in xyxy format
|
||||
"""
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device)
|
||||
|
||||
h, w = masks.shape[-2:]
|
||||
|
||||
y = torch.arange(0, h, dtype=torch.float)
|
||||
x = torch.arange(0, w, dtype=torch.float)
|
||||
y, x = torch.meshgrid(y, x)
|
||||
|
||||
x_mask = (masks * x.unsqueeze(0))
|
||||
x_max = x_mask.flatten(1).max(-1)[0]
|
||||
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
y_mask = (masks * y.unsqueeze(0))
|
||||
y_max = y_mask.flatten(1).max(-1)[0]
|
||||
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
||||
@@ -0,0 +1,468 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
import datetime
|
||||
import pickle
|
||||
from packaging import version
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger(object):
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(
|
||||
"{}: {}".format(name, str(meter))
|
||||
)
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ''
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
data_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}',
|
||||
'max mem: {memory:.0f}'
|
||||
])
|
||||
else:
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}'
|
||||
])
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB))
|
||||
else:
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time)))
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('{} Total time: {} ({:.4f} s / it)'.format(
|
||||
header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
||||
sha = 'N/A'
|
||||
diff = "clean"
|
||||
branch = 'N/A'
|
||||
try:
|
||||
sha = _run(['git', 'rev-parse', 'HEAD'])
|
||||
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
||||
diff = _run(['git', 'diff-index', 'HEAD'])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor(object):
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], :img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError('not supported')
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop('force', False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print('Not using distributed mode')
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}'.format(
|
||||
args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(
|
||||
input, size, scale_factor, mode, align_corners
|
||||
)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Plotting utilities to visualize training logs.
|
||||
"""
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pathlib import Path, PurePath
|
||||
|
||||
|
||||
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
|
||||
'''
|
||||
Function to plot specific fields from training log(s). Plots both training and test results.
|
||||
|
||||
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
|
||||
- fields = which results to plot from each log file - plots both training and test for each field.
|
||||
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
|
||||
- log_name = optional, name of log file if different than default 'log.txt'.
|
||||
|
||||
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
|
||||
- solid lines are training results, dashed lines are test results.
|
||||
|
||||
'''
|
||||
func_name = "plot_utils.py::plot_logs"
|
||||
|
||||
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
|
||||
# convert single Path to list to avoid 'not iterable' error
|
||||
|
||||
if not isinstance(logs, list):
|
||||
if isinstance(logs, PurePath):
|
||||
logs = [logs]
|
||||
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
|
||||
else:
|
||||
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
|
||||
Expect list[Path] or single Path obj, received {type(logs)}")
|
||||
|
||||
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
|
||||
for i, dir in enumerate(logs):
|
||||
if not isinstance(dir, PurePath):
|
||||
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
|
||||
if not dir.exists():
|
||||
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
|
||||
# verify log_name exists
|
||||
fn = Path(dir / log_name)
|
||||
if not fn.exists():
|
||||
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
|
||||
print(f"--> full path of missing log file: {fn}")
|
||||
return
|
||||
|
||||
# load log file(s) and plot
|
||||
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
|
||||
|
||||
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
|
||||
|
||||
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
|
||||
for j, field in enumerate(fields):
|
||||
if field == 'mAP':
|
||||
coco_eval = pd.DataFrame(
|
||||
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
|
||||
).ewm(com=ewm_col).mean()
|
||||
axs[j].plot(coco_eval, c=color)
|
||||
else:
|
||||
df.interpolate().ewm(com=ewm_col).mean().plot(
|
||||
y=[f'train_{field}', f'test_{field}'],
|
||||
ax=axs[j],
|
||||
color=[color] * 2,
|
||||
style=['-', '--']
|
||||
)
|
||||
for ax, field in zip(axs, fields):
|
||||
ax.legend([Path(p).name for p in logs])
|
||||
ax.set_title(field)
|
||||
|
||||
|
||||
def plot_precision_recall(files, naming_scheme='iter'):
|
||||
if naming_scheme == 'exp_id':
|
||||
# name becomes exp_id
|
||||
names = [f.parts[-3] for f in files]
|
||||
elif naming_scheme == 'iter':
|
||||
names = [f.stem for f in files]
|
||||
else:
|
||||
raise ValueError(f'not supported {naming_scheme}')
|
||||
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
|
||||
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
|
||||
data = torch.load(f)
|
||||
# precision is n_iou, n_points, n_cat, n_area, max_det
|
||||
precision = data['precision']
|
||||
recall = data['params'].recThrs
|
||||
scores = data['scores']
|
||||
# take precision for all classes, all areas and 100 detections
|
||||
precision = precision[0, :, :, 0, -1].mean(1)
|
||||
scores = scores[0, :, :, 0, -1].mean(1)
|
||||
prec = precision.mean()
|
||||
rec = data['recall'][0, :, 0, -1].mean()
|
||||
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
|
||||
f'score={scores.mean():0.3f}, ' +
|
||||
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
|
||||
)
|
||||
axs[0].plot(recall, precision, c=color)
|
||||
axs[1].plot(recall, scores, c=color)
|
||||
|
||||
axs[0].set_title('Precision / Recall')
|
||||
axs[0].legend(names)
|
||||
axs[1].set_title('Scores / Recall')
|
||||
axs[1].legend(names)
|
||||
return fig, axs
|
||||
@@ -0,0 +1,499 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
import h5py
|
||||
import pickle
|
||||
import fnmatch
|
||||
import cv2
|
||||
from time import time
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
|
||||
|
||||
|
||||
def flatten_list(l):
|
||||
return [item for sublist in l for item in sublist]
|
||||
|
||||
|
||||
class EpisodicDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_path_list,
|
||||
camera_names,
|
||||
norm_stats,
|
||||
episode_ids,
|
||||
episode_len,
|
||||
chunk_size,
|
||||
policy_class,
|
||||
):
|
||||
super(EpisodicDataset).__init__()
|
||||
self.episode_ids = episode_ids
|
||||
self.dataset_path_list = dataset_path_list
|
||||
self.camera_names = camera_names
|
||||
self.norm_stats = norm_stats
|
||||
self.episode_len = episode_len
|
||||
self.chunk_size = chunk_size
|
||||
self.cumulative_len = np.cumsum(self.episode_len)
|
||||
self.max_episode_len = max(episode_len)
|
||||
self.policy_class = policy_class
|
||||
if self.policy_class == "Diffusion":
|
||||
self.augment_images = True
|
||||
else:
|
||||
self.augment_images = False
|
||||
self.transformations = None
|
||||
self.__getitem__(0) # initialize self.is_sim and self.transformations
|
||||
self.is_sim = False
|
||||
|
||||
# def __len__(self):
|
||||
# return sum(self.episode_len)
|
||||
|
||||
def _locate_transition(self, index):
|
||||
assert index < self.cumulative_len[-1]
|
||||
episode_index = np.argmax(
|
||||
self.cumulative_len > index
|
||||
) # argmax returns first True index
|
||||
start_ts = index - (
|
||||
self.cumulative_len[episode_index] - self.episode_len[episode_index]
|
||||
)
|
||||
episode_id = self.episode_ids[episode_index]
|
||||
return episode_id, start_ts
|
||||
|
||||
def __getitem__(self, index):
|
||||
episode_id, start_ts = self._locate_transition(index)
|
||||
dataset_path = self.dataset_path_list[episode_id]
|
||||
try:
|
||||
# print(dataset_path)
|
||||
with h5py.File(dataset_path, "r") as root:
|
||||
try: # some legacy data does not have this attribute
|
||||
is_sim = root.attrs["sim"]
|
||||
except:
|
||||
is_sim = False
|
||||
compressed = root.attrs.get("compress", False)
|
||||
if "/base_action" in root:
|
||||
base_action = root["/base_action"][()]
|
||||
base_action = preprocess_base_action(base_action)
|
||||
action = np.concatenate([root["/action"][()], base_action], axis=-1)
|
||||
else:
|
||||
# TODO
|
||||
action = root["/action"][()]
|
||||
# dummy_base_action = np.zeros([action.shape[0], 2])
|
||||
# action = np.concatenate([action, dummy_base_action], axis=-1)
|
||||
original_action_shape = action.shape
|
||||
episode_len = original_action_shape[0]
|
||||
# get observation at start_ts only
|
||||
qpos = root["/observations/qpos"][start_ts]
|
||||
qvel = root["/observations/qvel"][start_ts]
|
||||
image_dict = dict()
|
||||
for cam_name in self.camera_names:
|
||||
image_dict[cam_name] = root[f"/observations/images/{cam_name}"][
|
||||
start_ts
|
||||
]
|
||||
|
||||
if compressed:
|
||||
for cam_name in image_dict.keys():
|
||||
decompressed_image = cv2.imdecode(image_dict[cam_name], 1)
|
||||
image_dict[cam_name] = np.array(decompressed_image)
|
||||
|
||||
# get all actions after and including start_ts
|
||||
if is_sim:
|
||||
action = action[start_ts:]
|
||||
action_len = episode_len - start_ts
|
||||
else:
|
||||
action = action[
|
||||
max(0, start_ts - 1) :
|
||||
] # hack, to make timesteps more aligned
|
||||
action_len = episode_len - max(
|
||||
0, start_ts - 1
|
||||
) # hack, to make timesteps more aligned
|
||||
|
||||
# self.is_sim = is_sim
|
||||
padded_action = np.zeros(
|
||||
(self.max_episode_len, original_action_shape[1]), dtype=np.float32
|
||||
)
|
||||
padded_action[:action_len] = action
|
||||
is_pad = np.zeros(self.max_episode_len)
|
||||
is_pad[action_len:] = 1
|
||||
|
||||
padded_action = padded_action[: self.chunk_size]
|
||||
is_pad = is_pad[: self.chunk_size]
|
||||
|
||||
# new axis for different cameras
|
||||
all_cam_images = []
|
||||
for cam_name in self.camera_names:
|
||||
all_cam_images.append(image_dict[cam_name])
|
||||
all_cam_images = np.stack(all_cam_images, axis=0)
|
||||
|
||||
# construct observations
|
||||
image_data = torch.from_numpy(all_cam_images)
|
||||
qpos_data = torch.from_numpy(qpos).float()
|
||||
action_data = torch.from_numpy(padded_action).float()
|
||||
is_pad = torch.from_numpy(is_pad).bool()
|
||||
|
||||
# channel last
|
||||
image_data = torch.einsum("k h w c -> k c h w", image_data)
|
||||
|
||||
# augmentation
|
||||
if self.transformations is None:
|
||||
print("Initializing transformations")
|
||||
original_size = image_data.shape[2:]
|
||||
ratio = 0.95
|
||||
self.transformations = [
|
||||
transforms.RandomCrop(
|
||||
size=[
|
||||
int(original_size[0] * ratio),
|
||||
int(original_size[1] * ratio),
|
||||
]
|
||||
),
|
||||
transforms.Resize(original_size, antialias=True),
|
||||
transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
|
||||
transforms.ColorJitter(
|
||||
brightness=0.3, contrast=0.4, saturation=0.5
|
||||
), # , hue=0.08)
|
||||
]
|
||||
|
||||
if self.augment_images:
|
||||
for transform in self.transformations:
|
||||
image_data = transform(image_data)
|
||||
|
||||
# normalize image and change dtype to float
|
||||
image_data = image_data / 255.0
|
||||
|
||||
if self.policy_class == "Diffusion":
|
||||
# normalize to [-1, 1]
|
||||
action_data = (
|
||||
(action_data - self.norm_stats["action_min"])
|
||||
/ (self.norm_stats["action_max"] - self.norm_stats["action_min"])
|
||||
) * 2 - 1
|
||||
else:
|
||||
# normalize to mean 0 std 1
|
||||
action_data = (
|
||||
action_data - self.norm_stats["action_mean"]
|
||||
) / self.norm_stats["action_std"]
|
||||
|
||||
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats[
|
||||
"qpos_std"
|
||||
]
|
||||
|
||||
except:
|
||||
print(f"Error loading {dataset_path} in __getitem__")
|
||||
quit()
|
||||
|
||||
# print(image_data.dtype, qpos_data.dtype, action_data.dtype, is_pad.dtype)
|
||||
return image_data, qpos_data, action_data, is_pad
|
||||
|
||||
|
||||
def get_norm_stats(dataset_path_list):
|
||||
all_qpos_data = []
|
||||
all_action_data = []
|
||||
all_episode_len = []
|
||||
|
||||
for dataset_path in dataset_path_list:
|
||||
try:
|
||||
with h5py.File(dataset_path, "r") as root:
|
||||
qpos = root["/observations/qpos"][()]
|
||||
qvel = root["/observations/qvel"][()]
|
||||
if "/base_action" in root:
|
||||
base_action = root["/base_action"][()]
|
||||
# base_action = preprocess_base_action(base_action)
|
||||
# action = np.concatenate([root["/action"][()], base_action], axis=-1)
|
||||
else:
|
||||
# TODO
|
||||
action = root["/action"][()]
|
||||
# dummy_base_action = np.zeros([action.shape[0], 2])
|
||||
# action = np.concatenate([action, dummy_base_action], axis=-1)
|
||||
except Exception as e:
|
||||
print(f"Error loading {dataset_path} in get_norm_stats")
|
||||
print(e)
|
||||
quit()
|
||||
all_qpos_data.append(torch.from_numpy(qpos))
|
||||
all_action_data.append(torch.from_numpy(action))
|
||||
all_episode_len.append(len(qpos))
|
||||
all_qpos_data = torch.cat(all_qpos_data, dim=0)
|
||||
all_action_data = torch.cat(all_action_data, dim=0)
|
||||
|
||||
# normalize action data
|
||||
action_mean = all_action_data.mean(dim=[0]).float()
|
||||
action_std = all_action_data.std(dim=[0]).float()
|
||||
action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
|
||||
|
||||
# normalize qpos data
|
||||
qpos_mean = all_qpos_data.mean(dim=[0]).float()
|
||||
qpos_std = all_qpos_data.std(dim=[0]).float()
|
||||
qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
|
||||
|
||||
action_min = all_action_data.min(dim=0).values.float()
|
||||
action_max = all_action_data.max(dim=0).values.float()
|
||||
|
||||
eps = 0.0001
|
||||
stats = {
|
||||
"action_mean": action_mean.numpy(),
|
||||
"action_std": action_std.numpy(),
|
||||
"action_min": action_min.numpy() - eps,
|
||||
"action_max": action_max.numpy() + eps,
|
||||
"qpos_mean": qpos_mean.numpy(),
|
||||
"qpos_std": qpos_std.numpy(),
|
||||
"example_qpos": qpos,
|
||||
}
|
||||
|
||||
return stats, all_episode_len
|
||||
|
||||
|
||||
def find_all_hdf5(dataset_dir, skip_mirrored_data):
|
||||
hdf5_files = []
|
||||
for root, dirs, files in os.walk(dataset_dir):
|
||||
for filename in fnmatch.filter(files, "*.hdf5"):
|
||||
if "features" in filename:
|
||||
continue
|
||||
if skip_mirrored_data and "mirror" in filename:
|
||||
continue
|
||||
hdf5_files.append(os.path.join(root, filename))
|
||||
print(f"Found {len(hdf5_files)} hdf5 files")
|
||||
return hdf5_files
|
||||
|
||||
|
||||
def BatchSampler(batch_size, episode_len_l, sample_weights):
|
||||
sample_probs = (
|
||||
np.array(sample_weights) / np.sum(sample_weights)
|
||||
if sample_weights is not None
|
||||
else None
|
||||
)
|
||||
# print("BatchSampler", sample_weights)
|
||||
sum_dataset_len_l = np.cumsum(
|
||||
[0] + [np.sum(episode_len) for episode_len in episode_len_l]
|
||||
)
|
||||
while True:
|
||||
batch = []
|
||||
for _ in range(batch_size):
|
||||
episode_idx = np.random.choice(len(episode_len_l), p=sample_probs)
|
||||
step_idx = np.random.randint(
|
||||
sum_dataset_len_l[episode_idx], sum_dataset_len_l[episode_idx + 1]
|
||||
)
|
||||
batch.append(step_idx)
|
||||
yield batch
|
||||
|
||||
|
||||
def load_data(
|
||||
dataset_dir_l,
|
||||
name_filter,
|
||||
camera_names,
|
||||
batch_size_train,
|
||||
batch_size_val,
|
||||
chunk_size,
|
||||
skip_mirrored_data=False,
|
||||
load_pretrain=False,
|
||||
policy_class=None,
|
||||
stats_dir_l=None,
|
||||
sample_weights=None,
|
||||
train_ratio=0.99,
|
||||
):
|
||||
if type(dataset_dir_l) == str:
|
||||
dataset_dir_l = [dataset_dir_l]
|
||||
dataset_path_list_list = [
|
||||
find_all_hdf5(dataset_dir, skip_mirrored_data) for dataset_dir in dataset_dir_l
|
||||
]
|
||||
num_episodes_0 = len(dataset_path_list_list[0])
|
||||
dataset_path_list = flatten_list(dataset_path_list_list)
|
||||
|
||||
dataset_path_list = [n for n in dataset_path_list if name_filter(n)]
|
||||
num_episodes_l = [
|
||||
len(dataset_path_list) for dataset_path_list in dataset_path_list_list
|
||||
]
|
||||
num_episodes_cumsum = np.cumsum(num_episodes_l)
|
||||
|
||||
# obtain train test split on dataset_dir_l[0]
|
||||
shuffled_episode_ids_0 = np.random.permutation(num_episodes_0)
|
||||
train_episode_ids_0 = shuffled_episode_ids_0[: int(train_ratio * num_episodes_0)]
|
||||
val_episode_ids_0 = shuffled_episode_ids_0[int(train_ratio * num_episodes_0) :]
|
||||
train_episode_ids_l = [train_episode_ids_0] + [
|
||||
np.arange(num_episodes) + num_episodes_cumsum[idx]
|
||||
for idx, num_episodes in enumerate(num_episodes_l[1:])
|
||||
]
|
||||
val_episode_ids_l = [val_episode_ids_0]
|
||||
train_episode_ids = np.concatenate(train_episode_ids_l)
|
||||
val_episode_ids = np.concatenate(val_episode_ids_l)
|
||||
print(
|
||||
f"\n\nData from: {dataset_dir_l}\n- Train on {[len(x) for x in train_episode_ids_l]} episodes\n- Test on {[len(x) for x in val_episode_ids_l]} episodes\n\n"
|
||||
)
|
||||
|
||||
# obtain normalization stats for qpos and action
|
||||
# if load_pretrain:
|
||||
# with open(os.path.join('/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all', 'dataset_stats.pkl'), 'rb') as f:
|
||||
# norm_stats = pickle.load(f)
|
||||
# print('Loaded pretrain dataset stats')
|
||||
_, all_episode_len = get_norm_stats(dataset_path_list)
|
||||
train_episode_len_l = [
|
||||
[all_episode_len[i] for i in train_episode_ids]
|
||||
for train_episode_ids in train_episode_ids_l
|
||||
]
|
||||
val_episode_len_l = [
|
||||
[all_episode_len[i] for i in val_episode_ids]
|
||||
for val_episode_ids in val_episode_ids_l
|
||||
]
|
||||
|
||||
train_episode_len = flatten_list(train_episode_len_l)
|
||||
val_episode_len = flatten_list(val_episode_len_l)
|
||||
if stats_dir_l is None:
|
||||
stats_dir_l = dataset_dir_l
|
||||
elif type(stats_dir_l) == str:
|
||||
stats_dir_l = [stats_dir_l]
|
||||
norm_stats, _ = get_norm_stats(
|
||||
flatten_list(
|
||||
[find_all_hdf5(stats_dir, skip_mirrored_data) for stats_dir in stats_dir_l]
|
||||
)
|
||||
)
|
||||
print(f"Norm stats from: {stats_dir_l}")
|
||||
|
||||
batch_sampler_train = BatchSampler(
|
||||
batch_size_train, train_episode_len_l, sample_weights
|
||||
)
|
||||
batch_sampler_val = BatchSampler(batch_size_val, val_episode_len_l, None)
|
||||
|
||||
# print(f'train_episode_len: {train_episode_len}, val_episode_len: {val_episode_len}, train_episode_ids: {train_episode_ids}, val_episode_ids: {val_episode_ids}')
|
||||
|
||||
# construct dataset and dataloader
|
||||
train_dataset = EpisodicDataset(
|
||||
dataset_path_list,
|
||||
camera_names,
|
||||
norm_stats,
|
||||
train_episode_ids,
|
||||
train_episode_len,
|
||||
chunk_size,
|
||||
policy_class,
|
||||
)
|
||||
val_dataset = EpisodicDataset(
|
||||
dataset_path_list,
|
||||
camera_names,
|
||||
norm_stats,
|
||||
val_episode_ids,
|
||||
val_episode_len,
|
||||
chunk_size,
|
||||
policy_class,
|
||||
)
|
||||
train_num_workers = (
|
||||
(8 if os.getlogin() == "zfu" else 16) if train_dataset.augment_images else 2
|
||||
)
|
||||
val_num_workers = 8 if train_dataset.augment_images else 2
|
||||
print(
|
||||
f"Augment images: {train_dataset.augment_images}, train_num_workers: {train_num_workers}, val_num_workers: {val_num_workers}"
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=batch_sampler_train,
|
||||
pin_memory=True,
|
||||
num_workers=train_num_workers,
|
||||
prefetch_factor=2,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_sampler=batch_sampler_val,
|
||||
pin_memory=True,
|
||||
num_workers=val_num_workers,
|
||||
prefetch_factor=2,
|
||||
)
|
||||
|
||||
return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim
|
||||
|
||||
|
||||
def calibrate_linear_vel(base_action, c=None):
|
||||
if c is None:
|
||||
c = 0.0 # 0.19
|
||||
v = base_action[..., 0]
|
||||
w = base_action[..., 1]
|
||||
base_action = base_action.copy()
|
||||
base_action[..., 0] = v - c * w
|
||||
return base_action
|
||||
|
||||
|
||||
def smooth_base_action(base_action):
|
||||
return np.stack(
|
||||
[
|
||||
np.convolve(base_action[:, i], np.ones(5) / 5, mode="same")
|
||||
for i in range(base_action.shape[1])
|
||||
],
|
||||
axis=-1,
|
||||
).astype(np.float32)
|
||||
|
||||
|
||||
def preprocess_base_action(base_action):
|
||||
# base_action = calibrate_linear_vel(base_action)
|
||||
base_action = smooth_base_action(base_action)
|
||||
|
||||
return base_action
|
||||
|
||||
|
||||
def postprocess_base_action(base_action):
|
||||
linear_vel, angular_vel = base_action
|
||||
linear_vel *= 1.0
|
||||
angular_vel *= 1.0
|
||||
# angular_vel = 0
|
||||
# if np.abs(linear_vel) < 0.05:
|
||||
# linear_vel = 0
|
||||
return np.array([linear_vel, angular_vel])
|
||||
|
||||
|
||||
### env utils
|
||||
|
||||
|
||||
def sample_box_pose():
|
||||
x_range = [0.0, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
cube_quat = np.array([1, 0, 0, 0])
|
||||
return np.concatenate([cube_position, cube_quat])
|
||||
|
||||
|
||||
def sample_insertion_pose():
|
||||
# Peg
|
||||
x_range = [0.1, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
peg_quat = np.array([1, 0, 0, 0])
|
||||
peg_pose = np.concatenate([peg_position, peg_quat])
|
||||
|
||||
# Socket
|
||||
x_range = [-0.2, -0.1]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
socket_quat = np.array([1, 0, 0, 0])
|
||||
socket_pose = np.concatenate([socket_position, socket_quat])
|
||||
|
||||
return peg_pose, socket_pose
|
||||
|
||||
|
||||
### helper functions
|
||||
|
||||
|
||||
def compute_dict_mean(epoch_dicts):
|
||||
result = {k: None for k in epoch_dicts[0]}
|
||||
num_items = len(epoch_dicts)
|
||||
for k in result:
|
||||
value_sum = 0
|
||||
for epoch_dict in epoch_dicts:
|
||||
value_sum += epoch_dict[k]
|
||||
result[k] = value_sum / num_items
|
||||
return result
|
||||
|
||||
|
||||
def detach_dict(d):
|
||||
new_d = dict()
|
||||
for k, v in d.items():
|
||||
new_d[k] = v.detach()
|
||||
return new_d
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
163
realman_src/realman_aloha/shadow_rm_act/test/test_camera.py
Normal file
163
realman_src/realman_aloha/shadow_rm_act/test/test_camera.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
import yaml
|
||||
import time
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
import collections
|
||||
import logging
|
||||
import dm_env
|
||||
import tracemalloc
|
||||
|
||||
|
||||
class DeviceAloha:
|
||||
def __init__(self, aloha_config):
|
||||
"""
|
||||
初始化设备
|
||||
|
||||
Args:
|
||||
device_name (str): 设备名称
|
||||
"""
|
||||
config_left_arm = aloha_config["rm_left_arm"]
|
||||
config_right_arm = aloha_config["rm_right_arm"]
|
||||
config_head_camera = aloha_config["head_camera"]
|
||||
config_bottom_camera = aloha_config["bottom_camera"]
|
||||
config_left_camera = aloha_config["left_camera"]
|
||||
config_right_camera = aloha_config["right_camera"]
|
||||
self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
|
||||
self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
|
||||
self.arm_left = RmArm(config_left_arm)
|
||||
self.arm_right = RmArm(config_right_arm)
|
||||
self.camera_left = RealSenseCamera(config_head_camera, False)
|
||||
self.camera_right = RealSenseCamera(config_bottom_camera, False)
|
||||
self.camera_bottom = RealSenseCamera(config_left_camera, False)
|
||||
self.camera_top = RealSenseCamera(config_right_camera, False)
|
||||
self.camera_left.start_camera()
|
||||
self.camera_right.start_camera()
|
||||
self.camera_bottom.start_camera()
|
||||
self.camera_top.start_camera()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
关闭摄像头
|
||||
"""
|
||||
self.camera_left.close()
|
||||
self.camera_right.close()
|
||||
self.camera_bottom.close()
|
||||
self.camera_top.close()
|
||||
|
||||
def get_qps(self):
|
||||
"""
|
||||
获取关节角度
|
||||
|
||||
Returns:
|
||||
np.array: 关节角度
|
||||
"""
|
||||
left_slave_arm_angle = self.arm_left.get_joint_angle()
|
||||
left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
|
||||
right_slave_arm_angle = self.arm_right.get_joint_angle()
|
||||
right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
|
||||
return np.concatenate([left_joint_angles_array, right_joint_angles_array])
|
||||
|
||||
def get_qvel(self):
|
||||
"""
|
||||
获取关节速度
|
||||
|
||||
Returns:
|
||||
np.array: 关节速度
|
||||
"""
|
||||
left_slave_arm_velocity = self.arm_left.get_joint_velocity()
|
||||
left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
|
||||
right_slave_arm_velocity = self.arm_right.get_joint_velocity()
|
||||
right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
|
||||
return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
|
||||
|
||||
def get_effort(self):
|
||||
"""
|
||||
获取关节力
|
||||
|
||||
Returns:
|
||||
np.array: 关节力
|
||||
"""
|
||||
left_slave_arm_effort = self.arm_left.get_joint_effort()
|
||||
left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
|
||||
right_slave_arm_effort = self.arm_right.get_joint_effort()
|
||||
right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
|
||||
return np.concatenate([left_joint_effort_array, right_joint_effort_array])
|
||||
|
||||
def get_images(self):
|
||||
"""
|
||||
获取图像
|
||||
|
||||
Returns:
|
||||
dict: 图像字典
|
||||
"""
|
||||
top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
|
||||
bottom_image, _, _, _ = self.camera_bottom.read_frame(True, False, False, False)
|
||||
left_image, _, _, _ = self.camera_left.read_frame(True, False, False, False)
|
||||
right_image, _, _, _ = self.camera_right.read_frame(True, False, False, False)
|
||||
return {
|
||||
"cam_high": top_image,
|
||||
"cam_low": bottom_image,
|
||||
"cam_left": left_image,
|
||||
"cam_right": right_image,
|
||||
}
|
||||
|
||||
def get_observation(self):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qps()
|
||||
obs["qvel"] = self.get_qvel()
|
||||
obs["effort"] = self.get_effort()
|
||||
obs["images"] = self.get_images()
|
||||
# self.get_images()
|
||||
return obs
|
||||
|
||||
def reset(self):
|
||||
logging.info("Resetting the environment")
|
||||
_ = self.arm_left.set_joint_position(self.init_left_arm_angle[0:6])
|
||||
_ = self.arm_right.set_joint_position(self.init_right_arm_angle[0:6])
|
||||
self.arm_left.set_gripper_position(0)
|
||||
self.arm_right.set_gripper_position(0)
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.FIRST,
|
||||
reward=0,
|
||||
discount=None,
|
||||
observation=self.get_observation(),
|
||||
)
|
||||
|
||||
def step(self, target_angle):
|
||||
self.arm_left.set_joint_canfd_position(target_angle[0:6])
|
||||
self.arm_right.set_joint_canfd_position(target_angle[7:13])
|
||||
self.arm_left.set_gripper_position(target_angle[6])
|
||||
self.arm_right.set_gripper_position(target_angle[13])
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID,
|
||||
reward=0,
|
||||
discount=None,
|
||||
observation=self.get_observation(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
aloha_config = config["robot_env"]
|
||||
device = DeviceAloha(aloha_config)
|
||||
device.reset()
|
||||
image_list = []
|
||||
tager_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
|
||||
while True:
|
||||
tracemalloc.start() # 启动内存跟踪
|
||||
|
||||
tager_angle = np.array([angle + 0.1 if i not in [6, 13] else angle for i, angle in enumerate(tager_angle)])
|
||||
time_step = time.time()
|
||||
timestep = device.step(tager_angle)
|
||||
logging.info(f"Time: {time.time() - time_step}")
|
||||
image_list.append(timestep.observation["images"])
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
top_stats = snapshot.statistics('lineno')
|
||||
# del timestep
|
||||
print("[ Top 10 ]")
|
||||
for stat in top_stats[:10]:
|
||||
print(stat)
|
||||
# logging.info(f"Images: {obs}")
|
||||
32
realman_src/realman_aloha/shadow_rm_act/test/test_h5.py
Normal file
32
realman_src/realman_aloha/shadow_rm_act/test/test_h5.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import os
|
||||
# import time
|
||||
import yaml
|
||||
import torch
|
||||
import pickle
|
||||
import dm_env
|
||||
import logging
|
||||
import collections
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision import transforms
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
from shadow_act.models.latent_model import Latent_Model_Transformer
|
||||
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||
from shadow_act.utils.utils import (
|
||||
load_data,
|
||||
sample_box_pose,
|
||||
sample_insertion_pose,
|
||||
compute_dict_mean,
|
||||
set_seed,
|
||||
detach_dict,
|
||||
)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
print('daasdas')
|
||||
147
realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py
Normal file
147
realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import h5py
|
||||
import argparse
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from constants import DT
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||
STATE_NAMES = JOINT_NAMES + ["gripper"]
|
||||
|
||||
def load_hdf5(dataset_dir, dataset_name):
|
||||
dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
|
||||
if not os.path.isfile(dataset_path):
|
||||
print(f'Dataset does not exist at \n{dataset_path}\n')
|
||||
exit()
|
||||
|
||||
with h5py.File(dataset_path, 'r') as root:
|
||||
is_sim = root.attrs['sim']
|
||||
qpos = root['/observations/qpos'][()]
|
||||
qvel = root['/observations/qvel'][()]
|
||||
action = root['/action'][()]
|
||||
image_dict = dict()
|
||||
for cam_name in root[f'/observations/images/'].keys():
|
||||
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
|
||||
|
||||
return qpos, qvel, action, image_dict
|
||||
|
||||
def main(args):
|
||||
dataset_dir = args['dataset_dir']
|
||||
episode_idx = args['episode_idx']
|
||||
dataset_name = f'episode_{episode_idx}'
|
||||
|
||||
qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
|
||||
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
|
||||
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
|
||||
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
|
||||
|
||||
|
||||
def save_videos(video, dt, video_path=None):
|
||||
if isinstance(video, list):
|
||||
cam_names = list(video[0].keys())
|
||||
h, w, _ = video[0][cam_names[0]].shape
|
||||
w = w * len(cam_names)
|
||||
fps = int(1/dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for ts, image_dict in enumerate(video):
|
||||
images = []
|
||||
for cam_name in cam_names:
|
||||
image = image_dict[cam_name]
|
||||
image = image[:, :, [2, 1, 0]] # swap B and R channel
|
||||
images.append(image)
|
||||
images = np.concatenate(images, axis=1)
|
||||
out.write(images)
|
||||
out.release()
|
||||
print(f'Saved video to: {video_path}')
|
||||
elif isinstance(video, dict):
|
||||
cam_names = list(video.keys())
|
||||
all_cam_videos = []
|
||||
for cam_name in cam_names:
|
||||
all_cam_videos.append(video[cam_name])
|
||||
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
|
||||
|
||||
n_frames, h, w, _ = all_cam_videos.shape
|
||||
fps = int(1 / dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for t in range(n_frames):
|
||||
image = all_cam_videos[t]
|
||||
image = image[:, :, [2, 1, 0]] # swap B and R channel
|
||||
out.write(image)
|
||||
out.release()
|
||||
print(f'Saved video to: {video_path}')
|
||||
|
||||
|
||||
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
|
||||
if label_overwrite:
|
||||
label1, label2 = label_overwrite
|
||||
else:
|
||||
label1, label2 = 'State', 'Command'
|
||||
|
||||
qpos = np.array(qpos_list) # ts, dim
|
||||
command = np.array(command_list)
|
||||
num_ts, num_dim = qpos.shape
|
||||
h, w = 2, num_dim
|
||||
num_figs = num_dim
|
||||
fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs))
|
||||
|
||||
# plot joint state
|
||||
all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(qpos[:, dim_idx], label=label1)
|
||||
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||
ax.legend()
|
||||
|
||||
# plot arm command
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(command[:, dim_idx], label=label2)
|
||||
ax.legend()
|
||||
|
||||
if ylim:
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.set_ylim(ylim)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
print(f'Saved qpos plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
def visualize_timestamp(t_list, dataset_path):
|
||||
plot_path = dataset_path.replace('.pkl', '_timestamp.png')
|
||||
h, w = 4, 10
|
||||
fig, axs = plt.subplots(2, 1, figsize=(w, h*2))
|
||||
# process t_list
|
||||
t_float = []
|
||||
for secs, nsecs in t_list:
|
||||
t_float.append(secs + nsecs * 10E-10)
|
||||
t_float = np.array(t_float)
|
||||
|
||||
ax = axs[0]
|
||||
ax.plot(np.arange(len(t_float)), t_float)
|
||||
ax.set_title(f'Camera frame timestamps')
|
||||
ax.set_xlabel('timestep')
|
||||
ax.set_ylabel('time (sec)')
|
||||
|
||||
ax = axs[1]
|
||||
ax.plot(np.arange(len(t_float)-1), t_float[:-1] - t_float[1:])
|
||||
ax.set_title(f'dt')
|
||||
ax.set_xlabel('timestep')
|
||||
ax.set_ylabel('time (sec)')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
print(f'Saved timestamp plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
|
||||
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False)
|
||||
main(vars(parser.parse_args()))
|
||||
10
realman_src/realman_aloha/shadow_rm_aloha/.gitignore
vendored
Normal file
10
realman_src/realman_aloha/shadow_rm_aloha/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
__pycache__/
|
||||
build/
|
||||
devel/
|
||||
dist/
|
||||
data/
|
||||
.catkin_workspace
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pt
|
||||
.vscode/
|
||||
3
realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore
generated
vendored
Normal file
3
realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
1
realman_src/realman_aloha/shadow_rm_aloha/.idea/.name
generated
Normal file
1
realman_src/realman_aloha/shadow_rm_aloha/.idea/.name
generated
Normal file
@@ -0,0 +1 @@
|
||||
aloha_data_synchronizer.py
|
||||
17
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
17
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,17 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="4">
|
||||
<item index="0" class="java.lang.String" itemvalue="tensorboard" />
|
||||
<item index="1" class="java.lang.String" itemvalue="thop" />
|
||||
<item index="2" class="java.lang.String" itemvalue="torch" />
|
||||
<item index="3" class="java.lang.String" itemvalue="torchvision" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
6
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml
generated
Normal file
7
realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.11 (随箱软件)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (随箱软件)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml
generated
Normal file
8
realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" filepath="$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
12
realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml
generated
Normal file
12
realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml
generated
Normal file
@@ -0,0 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
0
realman_src/realman_aloha/shadow_rm_aloha/README.md
Normal file
0
realman_src/realman_aloha/shadow_rm_aloha/README.md
Normal file
@@ -0,0 +1,38 @@
|
||||
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset'
|
||||
dataset_name: 'episode'
|
||||
max_timesteps: 500
|
||||
state_dim: 14
|
||||
overwrite: False
|
||||
arm_axis: 6
|
||||
camera_names:
|
||||
- 'cam_high'
|
||||
- 'cam_low'
|
||||
- 'cam_left'
|
||||
- 'cam_right'
|
||||
ros_topics:
|
||||
camera_left: '/camera_left/rgb/image_raw'
|
||||
camera_right: '/camera_right/rgb/image_raw'
|
||||
camera_bottom: '/camera_bottom/rgb/image_raw'
|
||||
camera_head: '/camera_head/rgb/image_raw'
|
||||
left_master_arm: '/left_master_arm_joint_states'
|
||||
left_slave_arm: '/left_slave_arm_joint_states'
|
||||
right_master_arm: '/right_master_arm_joint_states'
|
||||
right_slave_arm: '/right_slave_arm_joint_states'
|
||||
left_aloha_state: '/left_slave_arm_aloha_state'
|
||||
right_aloha_state: '/right_slave_arm_aloha_state'
|
||||
robot_env: {
|
||||
# TODO change the path to the correct one
|
||||
rm_left_arm: '/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml',
|
||||
rm_right_arm: '/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml',
|
||||
arm_axis: 6,
|
||||
head_camera: '241122071186',
|
||||
bottom_camera: '152122078546',
|
||||
left_camera: '150622070125',
|
||||
right_camera: '151222072576',
|
||||
init_left_arm_angle: [7.235, 31.816, 51.237, 2.463, 91.054, 12.04, 0.0],
|
||||
init_right_arm_angle: [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868, 0.0]
|
||||
# init_left_arm_angle: [6.681, 38.496, 66.093, -1.141, 74.529, 3.076, 0.0],
|
||||
# init_right_arm_angle: [-4.79, 37.062, 72.393, -0.477, 68.593, -9.526, 0.0]
|
||||
# init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
|
||||
# init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
arm_ip: "192.168.1.18"
|
||||
arm_port: 8080
|
||||
arm_axis: 6
|
||||
local_ip: "192.168.1.101"
|
||||
local_port: 8089
|
||||
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
|
||||
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
|
||||
get_vel: True
|
||||
get_torque: True
|
||||
@@ -0,0 +1,9 @@
|
||||
arm_ip: "192.168.1.19"
|
||||
arm_port: 8080
|
||||
arm_axis: 6
|
||||
local_ip: "192.168.1.101"
|
||||
local_port: 8090
|
||||
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
|
||||
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
|
||||
get_vel: True
|
||||
get_torque: True
|
||||
@@ -0,0 +1,4 @@
|
||||
port: /dev/ttyUSB1
|
||||
baudrate: 460800
|
||||
hex_data: "55 AA 02 00 00 67"
|
||||
arm_axis: 6
|
||||
@@ -0,0 +1,4 @@
|
||||
port: /dev/ttyUSB0
|
||||
baudrate: 460800
|
||||
hex_data: "55 AA 02 00 00 67"
|
||||
arm_axis: 6
|
||||
@@ -0,0 +1,6 @@
|
||||
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset/20250102'
|
||||
dataset_name: 'episode'
|
||||
episode_idx: 1
|
||||
FPS: 30
|
||||
# joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate", "J7"] # 7 joints
|
||||
joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] # 6 joints
|
||||
39
realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml
Normal file
39
realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[tool.poetry]
|
||||
name = "shadow_rm_aloha"
|
||||
version = "0.1.1"
|
||||
description = "aloha package, use D435 and Realman robot arm to build aloha to collect data"
|
||||
readme = "README.md"
|
||||
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||
license = "MIT"
|
||||
# include = ["realman_vision/pytransform/_pytransform.so",]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux amd64",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10"
|
||||
matplotlib = ">=3.9.2"
|
||||
h5py = ">=3.12.1"
|
||||
# rospy = ">=1.17.0"
|
||||
# shadow_rm_robot = { git = "https://github.com/Shadow2223/shadow_rm_robot.git", branch = "main" }
|
||||
# shadow_camera = { git = "https://github.com/Shadow2223/shadow_camera.git", branch = "main" }
|
||||
|
||||
|
||||
|
||||
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
|
||||
pytest = ">=8.3"
|
||||
black = ">=24.10.0"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.8.4"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,42 @@
|
||||
cmake_minimum_required(VERSION 3.0.2)
|
||||
project(shadow_rm_aloha)
|
||||
|
||||
find_package(catkin REQUIRED COMPONENTS
|
||||
rospy
|
||||
sensor_msgs
|
||||
cv_bridge
|
||||
image_transport
|
||||
std_msgs
|
||||
message_generation
|
||||
)
|
||||
|
||||
add_service_files(
|
||||
FILES
|
||||
GetArmStatus.srv
|
||||
GetImage.srv
|
||||
MoveArm.srv
|
||||
)
|
||||
|
||||
generate_messages(
|
||||
DEPENDENCIES
|
||||
sensor_msgs
|
||||
std_msgs
|
||||
)
|
||||
|
||||
catkin_package(
|
||||
CATKIN_DEPENDS message_runtime rospy std_msgs
|
||||
)
|
||||
|
||||
include_directories(
|
||||
${catkin_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
install(PROGRAMS
|
||||
arm_node/slave_arm_publisher.py
|
||||
arm_node/master_arm_publisher.py
|
||||
arm_node/slave_arm_pub_sub.py
|
||||
camera_node/camera_publisher.py
|
||||
data_sub_process/aloha_data_synchronizer.py
|
||||
data_sub_process/aloha_data_collect.py
|
||||
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import rospy
|
||||
import logging
|
||||
from shadow_rm_robot.servo_robotic_arm import ServoArm
|
||||
from sensor_msgs.msg import JointState
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
class MasterArmPublisher:
|
||||
def __init__(self):
|
||||
rospy.init_node("master_arm_publisher", anonymous=True)
|
||||
arm_config = rospy.get_param("~arm_config","config/servo_left_arm.yaml")
|
||||
hz = rospy.get_param("~hz", 250)
|
||||
self.joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
|
||||
|
||||
self.arm = ServoArm(arm_config)
|
||||
self.publisher = rospy.Publisher(self.joint_states_topic, JointState, queue_size=1)
|
||||
self.rate = rospy.Rate(hz) # 30 Hz
|
||||
|
||||
def publish_joint_states(self):
|
||||
while not rospy.is_shutdown():
|
||||
joint_state = JointState()
|
||||
joint_pos = self.arm.get_joint_actions()
|
||||
|
||||
joint_state.header.stamp = rospy.Time.now()
|
||||
joint_state.name = list(joint_pos.keys())
|
||||
joint_state.position = list(joint_pos.values())
|
||||
joint_state.velocity = [0.0] * len(joint_pos) # 速度(可根据需要修改)
|
||||
joint_state.effort = [0.0] * len(joint_pos) # 力矩(可根据需要修改)
|
||||
|
||||
# rospy.loginfo(f"{self.joint_states_topic}: {joint_state}")
|
||||
self.publisher.publish(joint_state)
|
||||
self.rate.sleep()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
arm_publisher = MasterArmPublisher()
|
||||
arm_publisher.publish_joint_states()
|
||||
except rospy.ROSInterruptException:
|
||||
pass
|
||||
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import rospy
|
||||
import logging
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from sensor_msgs.msg import JointState
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
class SlaveArmPublisher:
|
||||
def __init__(self):
|
||||
rospy.init_node("slave_arm_publisher", anonymous=True)
|
||||
arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
|
||||
hz = rospy.get_param("~hz", 250)
|
||||
joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
|
||||
joint_actions_topic = rospy.get_param("~joint_actions_topic", "/joint_actions")
|
||||
self.arm = RmArm(arm_config)
|
||||
self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
|
||||
self.subscriber = rospy.Subscriber(joint_actions_topic, JointState, self.callback)
|
||||
self.rate = rospy.Rate(hz)
|
||||
|
||||
def publish_joint_states(self):
|
||||
while not rospy.is_shutdown():
|
||||
joint_state = JointState()
|
||||
data = self.arm.get_integrate_data()
|
||||
# data = self.arm.get_arm_data()
|
||||
joint_state.header.stamp = rospy.Time.now()
|
||||
joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
|
||||
# joint_state.position = data["joint_angle"]
|
||||
joint_state.position = data['arm_angle']
|
||||
|
||||
# joint_state.position = list(data["arm_angle"])
|
||||
# joint_state.velocity = list(data["arm_velocity"])
|
||||
# joint_state.effort = list(data["arm_torque"])
|
||||
# rospy.loginfo(f"joint_states_topic: {joint_state}")
|
||||
self.publisher.publish(joint_state)
|
||||
self.rate.sleep()
|
||||
|
||||
def callback(self, data):
|
||||
# rospy.loginfo(f"Received joint_states_topic: {data}")
|
||||
start_time = rospy.Time.now()
|
||||
if data is None:
|
||||
return
|
||||
if data.name == ["joint_canfd"]:
|
||||
self.arm.set_joint_canfd_position(data.position[0:6])
|
||||
elif data.name == ["joint_j"]:
|
||||
self.arm.set_joint_position(data.position[0:6])
|
||||
|
||||
# self.arm.set_gripper_position(data.position[6])
|
||||
end_time = rospy.Time.now()
|
||||
time_cost_ms = (end_time - start_time).to_sec() * 1000
|
||||
rospy.loginfo(f"Time cost: {data.name},{time_cost_ms}")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
arm_publisher = SlaveArmPublisher()
|
||||
arm_publisher.publish_joint_states()
|
||||
except rospy.ROSInterruptException:
|
||||
pass
|
||||
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import rospy
|
||||
import logging
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from sensor_msgs.msg import JointState
|
||||
from std_msgs.msg import Int32MultiArray
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
class SlaveArmPublisher:
|
||||
def __init__(self):
|
||||
rospy.init_node("slave_arm_publisher", anonymous=True)
|
||||
arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
|
||||
hz = rospy.get_param("~hz", 250)
|
||||
joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
|
||||
aloha_state_topic = rospy.get_param("~aloha_state_topic", "/aloha_state")
|
||||
self.arm = RmArm(arm_config)
|
||||
self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
|
||||
self.aloha_state_pub = rospy.Publisher(aloha_state_topic, Int32MultiArray, queue_size=1)
|
||||
self.rate = rospy.Rate(hz)
|
||||
|
||||
def publish_joint_states(self):
|
||||
while not rospy.is_shutdown():
|
||||
joint_state = JointState()
|
||||
data = self.arm.get_integrate_data()
|
||||
joint_state.header.stamp = rospy.Time.now()
|
||||
joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6", "joint7"]
|
||||
joint_state.position = list(data["arm_angle"])
|
||||
joint_state.velocity = list(data["arm_velocity"])
|
||||
joint_state.effort = list(data["arm_torque"])
|
||||
# rospy.loginfo(f"joint_states_topic: {joint_state}")
|
||||
self.aloha_state_pub.publish(Int32MultiArray(data=data["aloha_state"].values()))
|
||||
self.publisher.publish(joint_state)
|
||||
self.rate.sleep()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
arm_publisher = SlaveArmPublisher()
|
||||
arm_publisher.publish_joint_states()
|
||||
except rospy.ROSInterruptException:
|
||||
pass
|
||||
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import rospy
|
||||
from sensor_msgs.msg import Image
|
||||
from std_msgs.msg import Header
|
||||
import numpy as np
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
|
||||
class CameraPublisher:
|
||||
def __init__(self):
|
||||
rospy.init_node('camera_publisher', anonymous=True)
|
||||
self.serial_number = rospy.get_param('~serial_number', None)
|
||||
hz = rospy.get_param('~hz', 30)
|
||||
rospy.loginfo(f"Serial number: {self.serial_number}")
|
||||
|
||||
self.rgb_topic = rospy.get_param('~rgb_topic', '/camera/rgb/image_raw')
|
||||
self.depth_topic = rospy.get_param('~depth_topic', '/camera/depth/image_raw')
|
||||
rospy.loginfo(f"RGB topic: {self.rgb_topic}")
|
||||
rospy.loginfo(f"Depth topic: {self.depth_topic}")
|
||||
|
||||
self.rgb_pub = rospy.Publisher(self.rgb_topic, Image, queue_size=10)
|
||||
# self.depth_pub = rospy.Publisher(self.depth_topic, Image, queue_size=10)
|
||||
self.rate = rospy.Rate(hz) # 30 Hz
|
||||
self.camera = RealSenseCamera(self.serial_number, False)
|
||||
|
||||
rospy.loginfo("Camera initialized")
|
||||
|
||||
def publish_images(self):
|
||||
self.camera.start_camera()
|
||||
rospy.loginfo("Camera started")
|
||||
while not rospy.is_shutdown():
|
||||
result = self.camera.read_frame(True, False, False, False)
|
||||
if result is None:
|
||||
rospy.logerr("Failed to read frame from camera")
|
||||
continue
|
||||
|
||||
color_image, depth_image, _, _ = result
|
||||
|
||||
if color_image is not None or depth_image is not None:
|
||||
rgb_msg = self.create_image_msg(color_image, "bgr8")
|
||||
# depth_msg = self.create_image_msg(depth_image, "mono16")
|
||||
|
||||
self.rgb_pub.publish(rgb_msg)
|
||||
# self.depth_pub.publish(depth_msg)
|
||||
# rospy.loginfo("Published RGB image")
|
||||
else:
|
||||
rospy.logwarn("Received None for color_image or depth_image")
|
||||
|
||||
self.rate.sleep()
|
||||
self.camera.stop_camera()
|
||||
rospy.loginfo("Camera stopped")
|
||||
|
||||
def create_image_msg(self, image, encoding):
|
||||
msg = Image()
|
||||
msg.header = Header()
|
||||
msg.header.stamp = rospy.Time.now()
|
||||
msg.height, msg.width = image.shape[:2]
|
||||
msg.encoding = encoding
|
||||
msg.is_bigendian = False
|
||||
msg.step = image.strides[0]
|
||||
msg.data = np.array(image).tobytes()
|
||||
return msg
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
camera_publisher = CameraPublisher()
|
||||
camera_publisher.publish_images()
|
||||
except rospy.ROSInterruptException:
|
||||
pass
|
||||
@@ -0,0 +1,112 @@
|
||||
import os
|
||||
import time
|
||||
import h5py
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
class DataCollector:
|
||||
def __init__(self, dataset_dir, dataset_name, max_timesteps, camera_names, state_dim, overwrite=False):
|
||||
self.arm_axis = 7
|
||||
self.dataset_dir = dataset_dir
|
||||
self.dataset_name = dataset_name
|
||||
self.max_timesteps = max_timesteps
|
||||
self.camera_names = camera_names
|
||||
self.state_dim = state_dim
|
||||
self.overwrite = overwrite
|
||||
self.data_dict = {
|
||||
'/observations/qpos': [],
|
||||
'/observations/qvel': [],
|
||||
'/observations/effort': [],
|
||||
'/action': [],
|
||||
}
|
||||
for cam_name in camera_names:
|
||||
self.data_dict[f'/observations/images/{cam_name}'] = []
|
||||
|
||||
# 自动检测和创建数据集目录
|
||||
self.create_dataset_dir()
|
||||
self.timesteps_collected = 0
|
||||
|
||||
|
||||
def create_dataset_dir(self):
|
||||
# 按照年月日创建目录
|
||||
date_str = datetime.now().strftime("%Y%m%d")
|
||||
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
|
||||
if not os.path.exists(self.dataset_dir):
|
||||
os.makedirs(self.dataset_dir)
|
||||
|
||||
def collect_data(self, ts, action):
|
||||
self.data_dict['/observations/qpos'].append(ts.observation['qpos'])
|
||||
self.data_dict['/observations/qvel'].append(ts.observation['qvel'])
|
||||
self.data_dict['/observations/effort'].append(ts.observation['effort'])
|
||||
self.data_dict['/action'].append(action)
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
|
||||
|
||||
def save_data(self):
|
||||
t0 = time.time()
|
||||
# 保存数据
|
||||
with h5py.File(self.dataset_path, mode='w', rdcc_nbytes=1024**2*2) as root:
|
||||
root.attrs['sim'] = False
|
||||
obs = root.create_group('observations')
|
||||
image = obs.create_group('images')
|
||||
for cam_name in self.camera_names:
|
||||
_ = image.create_dataset(cam_name, (self.max_timesteps, 480, 640, 3), dtype='uint8',
|
||||
chunks=(1, 480, 640, 3))
|
||||
_ = obs.create_dataset('qpos', (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset('qvel', (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset('effort', (self.max_timesteps, self.state_dim))
|
||||
_ = root.create_dataset('action', (self.max_timesteps, self.state_dim))
|
||||
|
||||
for name, array in self.data_dict.items():
|
||||
root[name][...] = array
|
||||
print(f'Saving: {time.time() - t0:.1f} secs')
|
||||
return True
|
||||
|
||||
def load_hdf5(self, orign_path, file):
|
||||
self.dataset_path = os.path.join(self.dataset_dir, file)
|
||||
if not os.path.isfile(orign_path):
|
||||
logging.error(f'Dataset does not exist at {orign_path}')
|
||||
exit()
|
||||
|
||||
with h5py.File(orign_path, 'r') as root:
|
||||
self.is_sim = root.attrs['sim']
|
||||
self.qpos = root['/observations/qpos'][()]
|
||||
self.qvel = root['/observations/qvel'][()]
|
||||
self.effort = root['/observations/effort'][()]
|
||||
self.action = root['/action'][()]
|
||||
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
|
||||
for cam_name in root[f'/observations/images/'].keys()}
|
||||
|
||||
self.qpos[:, self.arm_axis] = self.action[:, self.arm_axis]
|
||||
self.qpos[:, self.arm_axis*2+1] = self.action[:, self.arm_axis*2+1]
|
||||
|
||||
self.data_dict['/observations/qpos'] = self.qpos
|
||||
self.data_dict['/observations/qvel'] = self.qvel
|
||||
self.data_dict['/observations/effort'] = self.effort
|
||||
self.data_dict['/action'] = self.action
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f'/observations/images/{cam_name}'] = self.image_dict[cam_name]
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
用于更改夹爪数据,将从臂夹爪数据更改为主笔夹爪数据
|
||||
|
||||
"""
|
||||
dataset_dir = '/home/wang/project/shadow_rm_aloha/data'
|
||||
orign_dir = '/home/wang/project/shadow_rm_aloha/data/dataset/20241128'
|
||||
dataset_name = 'test'
|
||||
max_timesteps = 300
|
||||
camera_names = ['cam_high','cam_low','cam_left','cam_right']
|
||||
state_dim = 16
|
||||
collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
|
||||
for file in os.listdir(orign_dir):
|
||||
collector.__init__(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
|
||||
orign_path = os.path.join(orign_dir, file)
|
||||
print(orign_path)
|
||||
collector.load_hdf5(orign_path, file)
|
||||
collector.save_data()
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import h5py
|
||||
import yaml
|
||||
import logging
|
||||
import time
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
class DataValidator:
|
||||
def __init__(self, config):
|
||||
self.dataset_dir = config['dataset_dir']
|
||||
self.episode_idx = config['episode_idx']
|
||||
self.joint_names = config['joint_names']
|
||||
self.dataset_name = f'episode_{self.episode_idx}'
|
||||
self.dataset_path = os.path.join(self.dataset_dir, self.dataset_name + '.hdf5')
|
||||
self.state_names = self.joint_names + ["gripper"]
|
||||
self.arm = RmArm('/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml')
|
||||
|
||||
def load_hdf5(self):
|
||||
if not os.path.isfile(self.dataset_path):
|
||||
logging.error(f'Dataset does not exist at {self.dataset_path}')
|
||||
exit()
|
||||
|
||||
with h5py.File(self.dataset_path, 'r') as root:
|
||||
self.is_sim = root.attrs['sim']
|
||||
self.qpos = root['/observations/qpos'][()]
|
||||
# self.qvel = root['/observations/qvel'][()]
|
||||
# self.effort = root['/observations/effort'][()]
|
||||
self.action = root['/action'][()]
|
||||
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
|
||||
for cam_name in root[f'/observations/images/'].keys()}
|
||||
|
||||
def validate_data(self):
|
||||
# 验证位置数据
|
||||
if not self.qpos.shape[1] == 14:
|
||||
logging.error('qpos shape does not match expected number of joints')
|
||||
return False
|
||||
|
||||
logging.info('Data validation passed')
|
||||
return True
|
||||
|
||||
def control_robot(self):
|
||||
self.arm.set_joint_position(self.qpos[0][0:6])
|
||||
for qpos in self.qpos:
|
||||
logging.info(f'qpos: {qpos}')
|
||||
self.arm.set_joint_canfd_position(qpos[7:13])
|
||||
self.arm.set_gripper_position(qpos[13])
|
||||
time.sleep(0.035)
|
||||
|
||||
|
||||
def run(self):
|
||||
self.load_hdf5()
|
||||
if self.validate_data():
|
||||
self.control_robot()
|
||||
|
||||
def load_config(config_path):
|
||||
with open(config_path, 'r') as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
|
||||
validator = DataValidator(config)
|
||||
validator.run()
|
||||
@@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import h5py
|
||||
import yaml
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
class DataVisualizer:
|
||||
def __init__(self, config):
|
||||
self.dataset_dir = config['dataset_dir']
|
||||
self.episode_idx = config['episode_idx']
|
||||
self.dt = 1/config['FPS']
|
||||
self.joint_names = config['joint_names']
|
||||
self.state_names = self.joint_names + ["gripper"]
|
||||
# self.camera_names = config['camera_names']
|
||||
|
||||
def join_file_path(self, file_name):
|
||||
self.dataset_path = os.path.join(self.dataset_dir, file_name)
|
||||
|
||||
def load_hdf5(self):
|
||||
if not os.path.isfile(self.dataset_path):
|
||||
logging.error(f'Dataset does not exist at {self.dataset_path}')
|
||||
exit()
|
||||
|
||||
with h5py.File(self.dataset_path, 'r') as root:
|
||||
self.is_sim = root.attrs['sim']
|
||||
self.qpos = root['/observations/qpos'][()]
|
||||
# self.qvel = root['/observations/qvel'][()]
|
||||
# self.effort = root['/observations/effort'][()]
|
||||
self.action = root['/action'][()]
|
||||
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
|
||||
for cam_name in root[f'/observations/images/'].keys()}
|
||||
|
||||
def save_videos(self, video, dt, video_path=None):
|
||||
if isinstance(video, list):
|
||||
cam_names = list(video[0].keys())
|
||||
h, w, _ = video[0][cam_names[0]].shape
|
||||
w = w * len(cam_names)
|
||||
fps = int(1 / dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for image_dict in video:
|
||||
images = [image_dict[cam_name][:, :, [2, 1, 0]] for cam_name in cam_names]
|
||||
out.write(np.concatenate(images, axis=1))
|
||||
out.release()
|
||||
logging.info(f'Saved video to: {video_path}')
|
||||
elif isinstance(video, dict):
|
||||
cam_names = list(video.keys())
|
||||
all_cam_videos = np.concatenate([video[cam_name] for cam_name in cam_names], axis=2)
|
||||
n_frames, h, w, _ = all_cam_videos.shape
|
||||
fps = int(1 / dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for t in range(n_frames):
|
||||
out.write(all_cam_videos[t][:, :, [2, 1, 0]])
|
||||
out.release()
|
||||
logging.info(f'Saved video to: {video_path}')
|
||||
|
||||
def visualize_joints(self, qpos_list, command_list, plot_path, ylim=None, label_overwrite=None):
|
||||
label1, label2 = label_overwrite if label_overwrite else ('State', 'Command')
|
||||
qpos = np.array(qpos_list)
|
||||
command = np.array(command_list)
|
||||
num_ts, num_dim = qpos.shape
|
||||
logging.info(f'qpos shape: {qpos.shape}, command shape: {command.shape}')
|
||||
fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
|
||||
|
||||
all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(qpos[:, dim_idx], label=label1)
|
||||
ax.plot(command[:, dim_idx], label=label2)
|
||||
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||
ax.legend()
|
||||
if ylim:
|
||||
ax.set_ylim(ylim)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
logging.info(f'Saved qpos plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
def visualize_single(self, data_list, label, plot_path, ylim=None):
|
||||
data = np.array(data_list)
|
||||
num_ts, num_dim = data.shape
|
||||
fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
|
||||
|
||||
all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(data[:, dim_idx], label=label)
|
||||
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||
ax.legend()
|
||||
if ylim:
|
||||
ax.set_ylim(ylim)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
logging.info(f'Saved {label} plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
def visualize_timestamp(self, t_list):
|
||||
plot_path = self.dataset_path.replace('.hdf5', '_timestamp.png')
|
||||
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
|
||||
t_float = np.array([secs + nsecs * 1e-9 for secs, nsecs in t_list])
|
||||
|
||||
axs[0].plot(np.arange(len(t_float)), t_float)
|
||||
axs[0].set_title('Camera frame timestamps')
|
||||
axs[0].set_xlabel('timestep')
|
||||
axs[0].set_ylabel('time (sec)')
|
||||
|
||||
axs[1].plot(np.arange(len(t_float) - 1), t_float[:-1] - t_float[1:])
|
||||
axs[1].set_title('dt')
|
||||
axs[1].set_xlabel('timestep')
|
||||
axs[1].set_ylabel('time (sec)')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
logging.info(f'Saved timestamp plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
def run(self):
|
||||
for file_name in os.listdir(self.dataset_dir):
|
||||
if file_name.endswith('.hdf5'):
|
||||
self.join_file_path(file_name)
|
||||
self.load_hdf5()
|
||||
video_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_video.mp4'))
|
||||
self.save_videos(self.image_dict, self.dt, video_path)
|
||||
qpos_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_qpos.png'))
|
||||
self.visualize_joints(self.qpos, self.action, qpos_plot_path)
|
||||
# effort_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_effort.png'))
|
||||
# self.visualize_single(self.effort, 'effort', effort_plot_path)
|
||||
# error_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_error.png'))
|
||||
# self.visualize_single(self.action - self.qpos, 'tracking_error', error_plot_path)
|
||||
# self.visualize_timestamp(t_list) # TODO: Add timestamp visualization back
|
||||
|
||||
|
||||
def load_config(config_path):
|
||||
with open(config_path, 'r') as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
|
||||
visualizer = DataVisualizer(config)
|
||||
visualizer.run()
|
||||
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python3
|
||||
import time
|
||||
import yaml
|
||||
import rospy
|
||||
import dm_env
|
||||
import numpy as np
|
||||
import collections
|
||||
from datetime import datetime
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||
|
||||
|
||||
class DataSynchronizer:
|
||||
def __init__(self, config_path="config"):
|
||||
rospy.init_node("synchronizer", anonymous=True)
|
||||
|
||||
with open(config_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
self.init_left_arm_angle = config["robot_env"]["init_left_arm_angle"]
|
||||
self.init_right_arm_angle = config["robot_env"]["init_right_arm_angle"]
|
||||
self.arm_axis = config["robot_env"]["arm_axis"]
|
||||
self.camera_names = config["camera_names"]
|
||||
|
||||
# 创建订阅者
|
||||
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
|
||||
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
|
||||
self.camera_bottom_sub = Subscriber(
|
||||
config["ros_topics"]["camera_bottom"], Image
|
||||
)
|
||||
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
|
||||
|
||||
self.left_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_slave_arm_sub"], JointState
|
||||
)
|
||||
self.right_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_slave_arm_sub"], JointState
|
||||
)
|
||||
self.left_slave_arm_pub = rospy.Publisher(
|
||||
config["ros_topics"]["left_slave_arm_pub"], JointState, queue_size=1
|
||||
)
|
||||
self.right_slave_arm_pub = rospy.Publisher(
|
||||
config["ros_topics"]["right_slave_arm_pub"], JointState, queue_size=1
|
||||
)
|
||||
|
||||
# 创建同步器
|
||||
self.ats = ApproximateTimeSynchronizer(
|
||||
[
|
||||
self.camera_left_sub,
|
||||
self.camera_right_sub,
|
||||
self.camera_bottom_sub,
|
||||
self.camera_head_sub,
|
||||
self.left_slave_arm_sub,
|
||||
self.right_slave_arm_sub,
|
||||
],
|
||||
queue_size=1,
|
||||
slop=0.1,
|
||||
)
|
||||
|
||||
self.ats.registerCallback(self.callback)
|
||||
self.ts = None
|
||||
self.is_frist_step = True
|
||||
|
||||
def callback(
|
||||
self,
|
||||
camera_left_img,
|
||||
camera_right_img,
|
||||
camera_bottom_img,
|
||||
camera_head_img,
|
||||
left_slave_arm,
|
||||
right_slave_arm,
|
||||
):
|
||||
|
||||
# 将ROS图像消息转换为NumPy数组
|
||||
camera_left_np_img = np.frombuffer(
|
||||
camera_left_img.data, dtype=np.uint8
|
||||
).reshape(camera_left_img.height, camera_left_img.width, -1)
|
||||
camera_right_np_img = np.frombuffer(
|
||||
camera_right_img.data, dtype=np.uint8
|
||||
).reshape(camera_right_img.height, camera_right_img.width, -1)
|
||||
camera_bottom_np_img = np.frombuffer(
|
||||
camera_bottom_img.data, dtype=np.uint8
|
||||
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
|
||||
camera_head_np_img = np.frombuffer(
|
||||
camera_head_img.data, dtype=np.uint8
|
||||
).reshape(camera_head_img.height, camera_head_img.width, -1)
|
||||
|
||||
left_slave_arm_angle = left_slave_arm.position
|
||||
left_slave_arm_velocity = left_slave_arm.velocity
|
||||
left_slave_arm_force = left_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释
|
||||
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
|
||||
|
||||
right_slave_arm_angle = right_slave_arm.position
|
||||
right_slave_arm_velocity = right_slave_arm.velocity
|
||||
right_slave_arm_force = right_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释
|
||||
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
|
||||
|
||||
# 收集数据
|
||||
obs = collections.OrderedDict(
|
||||
{
|
||||
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
|
||||
"qvel": np.concatenate(
|
||||
[left_slave_arm_velocity, right_slave_arm_velocity]
|
||||
),
|
||||
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
|
||||
"images": {
|
||||
self.camera_names[0]: camera_head_np_img,
|
||||
self.camera_names[1]: camera_bottom_np_img,
|
||||
self.camera_names[2]: camera_left_np_img,
|
||||
self.camera_names[3]: camera_right_np_img,
|
||||
},
|
||||
}
|
||||
)
|
||||
self.ts = dm_env.TimeStep(
|
||||
step_type=(
|
||||
dm_env.StepType.FIRST if self.is_frist_step else dm_env.StepType.MID
|
||||
),
|
||||
reward=0.0,
|
||||
discount=1.0,
|
||||
observation=obs,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
|
||||
left_joint_state = JointState()
|
||||
left_joint_state.header.stamp = rospy.Time.now()
|
||||
left_joint_state.name = ["joint_j"]
|
||||
left_joint_state.position = self.init_left_arm_angle[0 : self.arm_axis + 1]
|
||||
|
||||
right_joint_state = JointState()
|
||||
right_joint_state.header.stamp = rospy.Time.now()
|
||||
right_joint_state.name = ["joint_j"]
|
||||
right_joint_state.position = self.init_right_arm_angle[0 : self.arm_axis + 1]
|
||||
|
||||
self.left_slave_arm_pub.publish(left_joint_state)
|
||||
self.right_slave_arm_pub.publish(right_joint_state)
|
||||
while self.ts is None:
|
||||
time.sleep(0.002)
|
||||
|
||||
return self.ts
|
||||
|
||||
|
||||
|
||||
def step(self, target_angle):
|
||||
self.is_frist_step = False
|
||||
left_joint_state = JointState()
|
||||
left_joint_state.header.stamp = rospy.Time.now()
|
||||
left_joint_state.name = ["joint_canfd"]
|
||||
left_joint_state.position = target_angle[0 : self.arm_axis + 1]
|
||||
# print("left_joint_state: ", left_joint_state)
|
||||
|
||||
right_joint_state = JointState()
|
||||
right_joint_state.header.stamp = rospy.Time.now()
|
||||
right_joint_state.name = ["joint_canfd"]
|
||||
right_joint_state.position = target_angle[self.arm_axis + 1 : (self.arm_axis + 1) * 2]
|
||||
# print("right_joint_state: ", right_joint_state)
|
||||
|
||||
self.left_slave_arm_pub.publish(left_joint_state)
|
||||
self.right_slave_arm_pub.publish(right_joint_state)
|
||||
# time.sleep(0.013)
|
||||
return self.ts
|
||||
|
||||
def run(self):
|
||||
rospy.loginfo("Starting ROS spin")
|
||||
data = np.concatenate([self.init_left_arm_angle, self.init_right_arm_angle])
|
||||
self.reset()
|
||||
# print("data: ", data)
|
||||
while not rospy.is_shutdown():
|
||||
self.step(data)
|
||||
rospy.sleep(0.010)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
synchronizer = DataSynchronizer("/home/rm/code/shadow_act/config/config.yaml")
|
||||
start_time = time.time()
|
||||
synchronizer.run()
|
||||
@@ -0,0 +1,280 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import time
|
||||
import h5py
|
||||
import yaml
|
||||
import rospy
|
||||
import dm_env
|
||||
import numpy as np
|
||||
import collections
|
||||
from datetime import datetime
|
||||
from std_msgs.msg import Int32MultiArray
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||
|
||||
|
||||
class DataCollector:
|
||||
def __init__(
|
||||
self,
|
||||
dataset_dir,
|
||||
dataset_name,
|
||||
max_timesteps,
|
||||
camera_names,
|
||||
state_dim,
|
||||
overwrite=False,
|
||||
):
|
||||
self.dataset_dir = dataset_dir
|
||||
self.dataset_name = dataset_name
|
||||
self.max_timesteps = max_timesteps
|
||||
self.camera_names = camera_names
|
||||
self.state_dim = state_dim
|
||||
self.overwrite = overwrite
|
||||
self.init_dict()
|
||||
self.create_dataset_dir()
|
||||
|
||||
def init_dict(self):
|
||||
self.data_dict = {
|
||||
"/observations/qpos": [],
|
||||
"/observations/qvel": [],
|
||||
"/observations/effort": [],
|
||||
"/action": [],
|
||||
}
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f"/observations/images/{cam_name}"] = []
|
||||
|
||||
def create_dataset_dir(self):
|
||||
# 按照年月日创建目录
|
||||
date_str = datetime.now().strftime("%Y%m%d")
|
||||
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
|
||||
if not os.path.exists(self.dataset_dir):
|
||||
os.makedirs(self.dataset_dir)
|
||||
|
||||
def create_file(self):
|
||||
# 检查数据集名称是否存在,如果存在则递增名称
|
||||
counter = 0
|
||||
dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
|
||||
if not self.overwrite:
|
||||
while os.path.exists(dataset_path + ".hdf5"):
|
||||
dataset_path = os.path.join(
|
||||
self.dataset_dir, f"{self.dataset_name}_{counter}"
|
||||
)
|
||||
counter += 1
|
||||
self.dataset_path = dataset_path
|
||||
|
||||
def collect_data(self, ts, action):
|
||||
self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
|
||||
self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
|
||||
self.data_dict["/observations/effort"].append(ts.observation["effort"])
|
||||
self.data_dict["/action"].append(action)
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f"/observations/images/{cam_name}"].append(
|
||||
ts.observation["images"][cam_name]
|
||||
)
|
||||
|
||||
def save_data(self):
|
||||
self.create_file()
|
||||
t0 = time.time()
|
||||
# 保存数据
|
||||
with h5py.File(
|
||||
self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
|
||||
) as root:
|
||||
root.attrs["sim"] = False
|
||||
obs = root.create_group("observations")
|
||||
image = obs.create_group("images")
|
||||
for cam_name in self.camera_names:
|
||||
_ = image.create_dataset(
|
||||
cam_name,
|
||||
(self.max_timesteps, 480, 640, 3),
|
||||
dtype="uint8",
|
||||
chunks=(1, 480, 640, 3),
|
||||
)
|
||||
_ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
|
||||
_ = root.create_dataset("action", (self.max_timesteps, self.state_dim))
|
||||
|
||||
for name, array in self.data_dict.items():
|
||||
root[name][...] = array
|
||||
print(f"Saving: {time.time() - t0:.1f} secs")
|
||||
return True
|
||||
|
||||
|
||||
class DataSynchronizer:
|
||||
def __init__(self, config_path="config"):
|
||||
rospy.init_node("synchronizer", anonymous=True)
|
||||
rospy.loginfo("ROS node initialized")
|
||||
with open(config_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
self.arm_axis = config["arm_axis"]
|
||||
# 创建订阅者
|
||||
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
|
||||
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
|
||||
self.camera_bottom_sub = Subscriber(
|
||||
config["ros_topics"]["camera_bottom"], Image
|
||||
)
|
||||
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
|
||||
|
||||
self.left_master_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_master_arm"], JointState
|
||||
)
|
||||
self.left_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_slave_arm"], JointState
|
||||
)
|
||||
self.right_master_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_master_arm"], JointState
|
||||
)
|
||||
self.right_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_slave_arm"], JointState
|
||||
)
|
||||
self.left_aloha_state_pub = rospy.Subscriber(
|
||||
config["ros_topics"]["left_aloha_state"],
|
||||
Int32MultiArray,
|
||||
self.aloha_state_callback,
|
||||
)
|
||||
|
||||
rospy.loginfo("Subscribers created")
|
||||
self.camera_names = config["camera_names"]
|
||||
|
||||
# 创建同步器
|
||||
self.ats = ApproximateTimeSynchronizer(
|
||||
[
|
||||
self.camera_left_sub,
|
||||
self.camera_right_sub,
|
||||
self.camera_bottom_sub,
|
||||
self.camera_head_sub,
|
||||
self.left_master_arm_sub,
|
||||
self.left_slave_arm_sub,
|
||||
self.right_master_arm_sub,
|
||||
self.right_slave_arm_sub,
|
||||
],
|
||||
queue_size=1,
|
||||
slop=0.05,
|
||||
)
|
||||
self.ats.registerCallback(self.callback)
|
||||
rospy.loginfo("Time synchronizer created and callback registered")
|
||||
|
||||
self.data_collector = DataCollector(
|
||||
dataset_dir=config["dataset_dir"],
|
||||
dataset_name=config["dataset_name"],
|
||||
max_timesteps=config["max_timesteps"],
|
||||
camera_names=config["camera_names"],
|
||||
state_dim=config["state_dim"],
|
||||
overwrite=config["overwrite"],
|
||||
)
|
||||
self.timesteps_collected = 0
|
||||
self.begin_collect = False
|
||||
self.last_time = None
|
||||
|
||||
def callback(
|
||||
self,
|
||||
camera_left_img,
|
||||
camera_right_img,
|
||||
camera_bottom_img,
|
||||
camera_head_img,
|
||||
left_master_arm,
|
||||
left_slave_arm,
|
||||
right_master_arm,
|
||||
right_slave_arm,
|
||||
):
|
||||
if self.begin_collect:
|
||||
self.timesteps_collected += 1
|
||||
rospy.loginfo(
|
||||
f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
|
||||
)
|
||||
else:
|
||||
self.timesteps_collected = 0
|
||||
return
|
||||
if self.timesteps_collected == 0:
|
||||
return
|
||||
current_time = time.time()
|
||||
if self.last_time is not None:
|
||||
frequency = 1.0 / (current_time - self.last_time)
|
||||
rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
|
||||
self.last_time = current_time
|
||||
# 将ROS图像消息转换为NumPy数组
|
||||
camera_left_np_img = np.frombuffer(
|
||||
camera_left_img.data, dtype=np.uint8
|
||||
).reshape(camera_left_img.height, camera_left_img.width, -1)
|
||||
camera_right_np_img = np.frombuffer(
|
||||
camera_right_img.data, dtype=np.uint8
|
||||
).reshape(camera_right_img.height, camera_right_img.width, -1)
|
||||
camera_bottom_np_img = np.frombuffer(
|
||||
camera_bottom_img.data, dtype=np.uint8
|
||||
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
|
||||
camera_head_np_img = np.frombuffer(
|
||||
camera_head_img.data, dtype=np.uint8
|
||||
).reshape(camera_head_img.height, camera_head_img.width, -1)
|
||||
|
||||
# 提取臂的角度,速度,力
|
||||
left_master_arm_angle = left_master_arm.position
|
||||
# left_master_arm_velocity = left_master_arm.velocity
|
||||
# left_master_arm_force = left_master_arm.effort
|
||||
|
||||
left_slave_arm_angle = left_slave_arm.position
|
||||
left_slave_arm_velocity = left_slave_arm.velocity
|
||||
left_slave_arm_force = left_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释
|
||||
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
|
||||
|
||||
right_master_arm_angle = right_master_arm.position
|
||||
# right_master_arm_velocity = right_master_arm.velocity
|
||||
# right_master_arm_force = right_master_arm.effort
|
||||
|
||||
right_slave_arm_angle = right_slave_arm.position
|
||||
right_slave_arm_velocity = right_slave_arm.velocity
|
||||
right_slave_arm_force = right_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释
|
||||
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
|
||||
|
||||
# 收集数据
|
||||
obs = collections.OrderedDict(
|
||||
{
|
||||
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
|
||||
"qvel": np.concatenate(
|
||||
[left_slave_arm_velocity, right_slave_arm_velocity]
|
||||
),
|
||||
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
|
||||
"images": {
|
||||
self.camera_names[0]: camera_head_np_img,
|
||||
self.camera_names[1]: camera_bottom_np_img,
|
||||
self.camera_names[2]: camera_left_np_img,
|
||||
self.camera_names[3]: camera_right_np_img,
|
||||
},
|
||||
}
|
||||
)
|
||||
print(self.camera_names[0])
|
||||
ts = dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
|
||||
)
|
||||
action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
|
||||
self.data_collector.collect_data(ts, action)
|
||||
|
||||
# 检查是否收集了足够的数据
|
||||
if self.timesteps_collected >= self.data_collector.max_timesteps:
|
||||
self.data_collector.save_data()
|
||||
|
||||
rospy.loginfo("Data collection complete")
|
||||
|
||||
self.data_collector.init_dict()
|
||||
self.begin_collect = False
|
||||
self.timesteps_collected = 0
|
||||
|
||||
def aloha_state_callback(self, data):
|
||||
if not self.begin_collect:
|
||||
self.aloha_state = data.data
|
||||
print(self.aloha_state[0], self.aloha_state[1])
|
||||
if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
|
||||
self.begin_collect = True
|
||||
|
||||
|
||||
def run(self):
|
||||
rospy.loginfo("Starting ROS spin")
|
||||
|
||||
rospy.spin()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
synchronizer = DataSynchronizer(
|
||||
"/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
|
||||
)
|
||||
synchronizer.run()
|
||||
@@ -0,0 +1,63 @@
|
||||
<launch>
|
||||
<!-- 左从臂节点 -->
|
||||
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
|
||||
<param name="aloha_state_topic" value="/left_slave_arm_aloha_state" type= "string"/>
|
||||
<param name="hz" value="120" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右从臂节点 -->
|
||||
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
|
||||
<param name="aloha_state_topic" value="/right_slave_arm_aloha_state" type= "string"/>
|
||||
<param name="hz" value="120" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左主臂节点 -->
|
||||
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="90" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右主臂节点 -->
|
||||
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="90" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右臂相机节点 -->
|
||||
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="151222072576" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左臂相机节点 -->
|
||||
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="150622070125" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 顶部相机节点 -->
|
||||
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="241122071186" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 底部相机节点 -->
|
||||
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="152122078546" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
</launch>
|
||||
@@ -0,0 +1,49 @@
|
||||
<launch>
|
||||
<!-- 左从臂节点 -->
|
||||
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
|
||||
<param name="joint_actions_topic" value="/left_slave_arm_joint_actions" type= "string"/>
|
||||
<param name="hz" value="90" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右从臂节点 -->
|
||||
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
|
||||
<param name="joint_actions_topic" value="/right_slave_arm_joint_actions" type= "string"/>
|
||||
<param name="hz" value="90" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右臂相机节点 -->
|
||||
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="151222072576" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左臂相机节点 -->
|
||||
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="150622070125" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 顶部相机节点 -->
|
||||
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="241122071186" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 底部相机节点 -->
|
||||
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="152122078546" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
</launch>
|
||||
@@ -0,0 +1,61 @@
|
||||
<launch>
|
||||
<!-- 左从臂节点 -->
|
||||
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右从臂节点 -->
|
||||
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左主臂节点 -->
|
||||
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右主臂节点 -->
|
||||
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右臂相机节点 -->
|
||||
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="216322070299" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左臂相机节点 -->
|
||||
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="216322074992" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 顶部相机节点 -->
|
||||
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="215322076086" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 底部相机节点 -->
|
||||
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="215222074360" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
</launch>
|
||||
@@ -0,0 +1,284 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import h5py
|
||||
import yaml
|
||||
import json
|
||||
import rospy
|
||||
import dm_env
|
||||
import socket
|
||||
import numpy as np
|
||||
import collections
|
||||
from datetime import datetime
|
||||
from std_msgs.msg import Int32MultiArray
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||
|
||||
|
||||
class DataCollector:
|
||||
def __init__(
|
||||
self,
|
||||
dataset_dir,
|
||||
dataset_name,
|
||||
max_timesteps,
|
||||
camera_names,
|
||||
state_dim,
|
||||
overwrite=False,
|
||||
):
|
||||
self.dataset_dir = dataset_dir
|
||||
self.dataset_name = dataset_name
|
||||
self.max_timesteps = max_timesteps
|
||||
self.camera_names = camera_names
|
||||
self.state_dim = state_dim
|
||||
self.overwrite = overwrite
|
||||
self.init_dict()
|
||||
self.create_dataset_dir()
|
||||
|
||||
def init_dict(self):
|
||||
self.data_dict = {
|
||||
"/observations/qpos": [],
|
||||
"/observations/qvel": [],
|
||||
"/observations/effort": [],
|
||||
"/action": [],
|
||||
}
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f"/observations/images/{cam_name}"] = []
|
||||
|
||||
def create_dataset_dir(self):
|
||||
# 按照年月日创建目录
|
||||
date_str = datetime.now().strftime("%Y%m%d")
|
||||
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
|
||||
if not os.path.exists(self.dataset_dir):
|
||||
os.makedirs(self.dataset_dir)
|
||||
|
||||
def create_file(self):
|
||||
# 检查数据集名称是否存在,如果存在则递增名称
|
||||
counter = 0
|
||||
dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
|
||||
if not self.overwrite:
|
||||
while os.path.exists(dataset_path + ".hdf5"):
|
||||
dataset_path = os.path.join(
|
||||
self.dataset_dir, f"{self.dataset_name}_{counter}"
|
||||
)
|
||||
counter += 1
|
||||
self.dataset_path = dataset_path
|
||||
|
||||
def collect_data(self, ts, action):
|
||||
self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
|
||||
self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
|
||||
self.data_dict["/observations/effort"].append(ts.observation["effort"])
|
||||
self.data_dict["/action"].append(action)
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f"/observations/images/{cam_name}"].append(
|
||||
ts.observation["images"][cam_name]
|
||||
)
|
||||
|
||||
def save_data(self):
|
||||
self.create_file()
|
||||
t0 = time.time()
|
||||
# 保存数据
|
||||
with h5py.File(
|
||||
self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
|
||||
) as root:
|
||||
root.attrs["sim"] = False
|
||||
obs = root.create_group("observations")
|
||||
image = obs.create_group("images")
|
||||
for cam_name in self.camera_names:
|
||||
_ = image.create_dataset(
|
||||
cam_name,
|
||||
(self.max_timesteps, 480, 640, 3),
|
||||
dtype="uint8",
|
||||
chunks=(1, 480, 640, 3),
|
||||
)
|
||||
_ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
|
||||
_ = root.create_dataset("action", (self.max_timesteps, self.state_dim))
|
||||
|
||||
for name, array in self.data_dict.items():
|
||||
root[name][...] = array
|
||||
print(f"Saving: {time.time() - t0:.1f} secs")
|
||||
return True
|
||||
|
||||
|
||||
class DataSynchronizer:
|
||||
def __init__(self, config_path="config"):
|
||||
rospy.init_node("synchronizer", anonymous=True)
|
||||
rospy.loginfo("ROS node initialized")
|
||||
with open(config_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
self.arm_axis = config["arm_axis"]
|
||||
# 创建订阅者
|
||||
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
|
||||
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
|
||||
self.camera_bottom_sub = Subscriber(
|
||||
config["ros_topics"]["camera_bottom"], Image
|
||||
)
|
||||
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
|
||||
|
||||
self.left_master_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_master_arm"], JointState
|
||||
)
|
||||
self.left_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_slave_arm"], JointState
|
||||
)
|
||||
self.right_master_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_master_arm"], JointState
|
||||
)
|
||||
self.right_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_slave_arm"], JointState
|
||||
)
|
||||
self.left_aloha_state_pub = rospy.Subscriber(
|
||||
config["ros_topics"]["left_aloha_state"],
|
||||
Int32MultiArray,
|
||||
self.aloha_state_callback,
|
||||
)
|
||||
|
||||
rospy.loginfo("Subscribers created")
|
||||
|
||||
# 创建同步器
|
||||
self.ats = ApproximateTimeSynchronizer(
|
||||
[
|
||||
self.camera_left_sub,
|
||||
self.camera_right_sub,
|
||||
self.camera_bottom_sub,
|
||||
self.camera_head_sub,
|
||||
self.left_master_arm_sub,
|
||||
self.left_slave_arm_sub,
|
||||
self.right_master_arm_sub,
|
||||
self.right_slave_arm_sub,
|
||||
],
|
||||
queue_size=1,
|
||||
slop=0.05,
|
||||
)
|
||||
self.ats.registerCallback(self.callback)
|
||||
rospy.loginfo("Time synchronizer created and callback registered")
|
||||
|
||||
self.data_collector = DataCollector(
|
||||
dataset_dir=config["dataset_dir"],
|
||||
dataset_name=config["dataset_name"],
|
||||
max_timesteps=config["max_timesteps"],
|
||||
camera_names=config["camera_names"],
|
||||
state_dim=config["state_dim"],
|
||||
overwrite=config["overwrite"],
|
||||
)
|
||||
self.timesteps_collected = 0
|
||||
self.begin_collect = False
|
||||
self.last_time = None
|
||||
|
||||
def callback(
|
||||
self,
|
||||
camera_left_img,
|
||||
camera_right_img,
|
||||
camera_bottom_img,
|
||||
camera_head_img,
|
||||
left_master_arm,
|
||||
left_slave_arm,
|
||||
right_master_arm,
|
||||
right_slave_arm,
|
||||
):
|
||||
if self.begin_collect:
|
||||
self.timesteps_collected += 1
|
||||
rospy.loginfo(
|
||||
f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
|
||||
)
|
||||
else:
|
||||
self.timesteps_collected = 0
|
||||
return
|
||||
if self.timesteps_collected == 0:
|
||||
return
|
||||
current_time = time.time()
|
||||
if self.last_time is not None:
|
||||
frequency = 1.0 / (current_time - self.last_time)
|
||||
rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
|
||||
self.last_time = current_time
|
||||
# 将ROS图像消息转换为NumPy数组
|
||||
camera_left_np_img = np.frombuffer(
|
||||
camera_left_img.data, dtype=np.uint8
|
||||
).reshape(camera_left_img.height, camera_left_img.width, -1)
|
||||
camera_right_np_img = np.frombuffer(
|
||||
camera_right_img.data, dtype=np.uint8
|
||||
).reshape(camera_right_img.height, camera_right_img.width, -1)
|
||||
camera_bottom_np_img = np.frombuffer(
|
||||
camera_bottom_img.data, dtype=np.uint8
|
||||
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
|
||||
camera_head_np_img = np.frombuffer(
|
||||
camera_head_img.data, dtype=np.uint8
|
||||
).reshape(camera_head_img.height, camera_head_img.width, -1)
|
||||
|
||||
# 提取臂的角度,速度,力
|
||||
left_master_arm_angle = left_master_arm.position
|
||||
# left_master_arm_velocity = left_master_arm.velocity
|
||||
# left_master_arm_force = left_master_arm.effort
|
||||
|
||||
left_slave_arm_angle = left_slave_arm.position
|
||||
left_slave_arm_velocity = left_slave_arm.velocity
|
||||
left_slave_arm_force = left_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释
|
||||
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
|
||||
|
||||
right_master_arm_angle = right_master_arm.position
|
||||
# right_master_arm_velocity = right_master_arm.velocity
|
||||
# right_master_arm_force = right_master_arm.effort
|
||||
|
||||
right_slave_arm_angle = right_slave_arm.position
|
||||
right_slave_arm_velocity = right_slave_arm.velocity
|
||||
right_slave_arm_force = right_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释
|
||||
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
|
||||
|
||||
# 收集数据
|
||||
obs = collections.OrderedDict(
|
||||
{
|
||||
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
|
||||
"qvel": np.concatenate(
|
||||
[left_slave_arm_velocity, right_slave_arm_velocity]
|
||||
),
|
||||
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
|
||||
"images": {
|
||||
"cam_front": camera_head_np_img,
|
||||
"cam_low": camera_bottom_np_img,
|
||||
"cam_left": camera_left_np_img,
|
||||
"cam_right": camera_right_np_img,
|
||||
},
|
||||
}
|
||||
)
|
||||
ts = dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
|
||||
)
|
||||
action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
|
||||
self.data_collector.collect_data(ts, action)
|
||||
|
||||
# 检查是否收集了足够的数据
|
||||
if self.timesteps_collected >= self.data_collector.max_timesteps:
|
||||
self.data_collector.save_data()
|
||||
|
||||
rospy.loginfo("Data collection complete")
|
||||
|
||||
self.data_collector.init_dict()
|
||||
self.begin_collect = False
|
||||
self.timesteps_collected = 0
|
||||
|
||||
def aloha_state_callback(self, data):
|
||||
if not self.begin_collect:
|
||||
self.aloha_state = data.data
|
||||
print(self.aloha_state[0], self.aloha_state[1])
|
||||
if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
|
||||
self.begin_collect = True
|
||||
|
||||
|
||||
def run(self):
|
||||
rospy.loginfo("Starting ROS spin")
|
||||
|
||||
rospy.spin()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
synchronizer = DataSynchronizer(
|
||||
"/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
|
||||
)
|
||||
synchronizer.run()
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
<?xml version="1.0"?>
|
||||
<package format="2">
|
||||
<name>shadow_rm_aloha</name>
|
||||
<version>0.0.1</version>
|
||||
<description>The shadow_rm_aloha package</description>
|
||||
|
||||
<maintainer email="your_email@example.com">Your Name</maintainer>
|
||||
|
||||
<license>TODO</license>
|
||||
|
||||
<buildtool_depend>catkin</buildtool_depend>
|
||||
|
||||
<build_depend>rospy</build_depend>
|
||||
<build_depend>sensor_msgs</build_depend>
|
||||
<build_depend>std_msgs</build_depend>
|
||||
<build_depend>cv_bridge</build_depend>
|
||||
<build_depend>image_transport</build_depend>
|
||||
<build_depend>message_generation</build_depend>
|
||||
<build_depend>message_runtime</build_depend>
|
||||
|
||||
<exec_depend>rospy</exec_depend>
|
||||
<exec_depend>sensor_msgs</exec_depend>
|
||||
<exec_depend>std_msgs</exec_depend>
|
||||
<exec_depend>cv_bridge</exec_depend>
|
||||
<exec_depend>image_transport</exec_depend>
|
||||
<exec_depend>message_runtime</exec_depend>
|
||||
|
||||
|
||||
<export>
|
||||
</export>
|
||||
</package>
|
||||
@@ -0,0 +1,5 @@
|
||||
# GetArmStatus.srv
|
||||
|
||||
---
|
||||
sensor_msgs/JointState joint_status
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
# GetImage.srv
|
||||
---
|
||||
bool success
|
||||
sensor_msgs/Image image
|
||||
@@ -0,0 +1,4 @@
|
||||
# MoveArm.srv
|
||||
float32[] joint_angle
|
||||
---
|
||||
bool success
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
49
realman_src/realman_aloha/shadow_rm_aloha/test/mu_test.py
Normal file
49
realman_src/realman_aloha/shadow_rm_aloha/test/mu_test.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
|
||||
def collect_data(arm_id, cam_id, data_queue, lock):
|
||||
while True:
|
||||
# 模拟数据采集
|
||||
arm_data = f"Arm {arm_id} data"
|
||||
cam_data = f"Cam {cam_id} data"
|
||||
|
||||
# 获取当前时间戳
|
||||
timestamp = time.time()
|
||||
|
||||
# 将数据放入队列
|
||||
with lock:
|
||||
data_queue.put((timestamp, arm_data, cam_data))
|
||||
|
||||
# 模拟高帧率
|
||||
time.sleep(0.01)
|
||||
|
||||
def main():
|
||||
num_arms = 4
|
||||
num_cams = 4
|
||||
|
||||
# 创建队列和锁
|
||||
data_queue = mp.Queue()
|
||||
lock = mp.Lock()
|
||||
|
||||
# 创建进程
|
||||
processes = []
|
||||
for i in range(num_arms):
|
||||
p = mp.Process(target=collect_data, args=(i, i, data_queue, lock))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
# 主进程处理数据
|
||||
try:
|
||||
while True:
|
||||
if not data_queue.empty():
|
||||
with lock:
|
||||
timestamp, arm_data, cam_data = data_queue.get()
|
||||
print(f"Timestamp: {timestamp}, {arm_data}, {cam_data}")
|
||||
except KeyboardInterrupt:
|
||||
for p in processes:
|
||||
p.terminate()
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,38 @@
|
||||
import os
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from shadow_rm_aloha.data_sub_process.aloha_data_synchronizer import DataCollector
|
||||
|
||||
def test_create_dataset_dir():
|
||||
# 设置测试参数
|
||||
dataset_dir = './test_data/dataset'
|
||||
dataset_name = 'test_episode'
|
||||
max_timesteps = 100
|
||||
camera_names = ['cam1', 'cam2']
|
||||
overwrite = False
|
||||
|
||||
# 清理旧的测试数据
|
||||
if os.path.exists(dataset_dir):
|
||||
shutil.rmtree(dataset_dir)
|
||||
|
||||
# 创建 DataCollector 实例并调用 create_dataset_dir
|
||||
collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, overwrite)
|
||||
|
||||
# 检查目录是否按预期创建
|
||||
date_str = datetime.now().strftime("%Y%m%d")
|
||||
expected_dir = os.path.join(dataset_dir, date_str)
|
||||
assert os.path.exists(expected_dir), f"Expected directory {expected_dir} does not exist."
|
||||
|
||||
# 检查文件名是否按预期递增
|
||||
expected_file = os.path.join(expected_dir, dataset_name + '.hdf5')
|
||||
assert collector.dataset_path == expected_file, f"Expected file path {expected_file}, but got {collector.dataset_path}"
|
||||
|
||||
# 再次调用 create_dataset_dir,检查文件名是否递增
|
||||
# collector.create_dataset_dir()
|
||||
expected_file_incremented = os.path.join(expected_dir, dataset_name + '_1.hdf5')
|
||||
assert collector.dataset_path == expected_file_incremented, f"Expected file path {expected_file_incremented}, but got {collector.dataset_path}"
|
||||
|
||||
print("All tests passed.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_create_dataset_dir()
|
||||
105
realman_src/realman_aloha/shadow_rm_aloha/test/udp_test.py
Normal file
105
realman_src/realman_aloha/shadow_rm_aloha/test/udp_test.py
Normal file
@@ -0,0 +1,105 @@
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
import random
|
||||
import socket
|
||||
import json
|
||||
import logging
|
||||
|
||||
# 设置日志级别
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
class test_udp():
|
||||
def __init__(self):
|
||||
arm_ip = '192.168.1.19'
|
||||
arm_port = 8080
|
||||
self.arm =socket.socket()
|
||||
self.arm.connect((arm_ip, arm_port))
|
||||
set_udp = {"command":"set_realtime_push","cycle":1,"enable":True,"port":8090,"ip":"192.168.1.101","custom":{"aloha_state":True,"joint_speed":True,"arm_current_status":True,"hand":False, "expand_state":True}}
|
||||
self.arm.send(json.dumps(set_udp).encode('utf-8'))
|
||||
state = self.arm.recv(1024)
|
||||
|
||||
logging.info(f"Send data to {arm_ip}:{arm_port}: {state}")
|
||||
|
||||
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
||||
# 设置套接字选项,允许端口复用
|
||||
self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
local_ip = "192.168.1.101"
|
||||
local_port = 8090
|
||||
self.udp_socket.bind((local_ip, local_port))
|
||||
self.BUFFER_SIZE = 1024
|
||||
|
||||
|
||||
|
||||
def set_udp(self):
|
||||
|
||||
while True:
|
||||
start_time = time.time()
|
||||
data, addr = self.udp_socket.recvfrom(self.BUFFER_SIZE)
|
||||
# 将接收到的UDP数据解码并解析为JSON
|
||||
data = json.loads(data.decode('utf-8'))
|
||||
end_time = time.time()
|
||||
print(f"Received data {data}")
|
||||
|
||||
udp_socket.close()
|
||||
|
||||
|
||||
def collect_arm_data(arm_id, queue, event):
|
||||
while True:
|
||||
data = f"Arm {arm_id} data {random.random()}"
|
||||
queue.put((arm_id, data))
|
||||
event.set()
|
||||
time.sleep(1)
|
||||
|
||||
def collect_camera_data(camera_id, queue, event):
|
||||
while True:
|
||||
data = f"Camera {camera_id} data {random.random()}"
|
||||
queue.put((camera_id, data))
|
||||
event.set()
|
||||
time.sleep(1)
|
||||
|
||||
def main():
|
||||
arm_queues = [multiprocessing.Queue() for _ in range(4)]
|
||||
camera_queues = [multiprocessing.Queue() for _ in range(4)]
|
||||
arm_events = [multiprocessing.Event() for _ in range(4)]
|
||||
camera_events = [multiprocessing.Event() for _ in range(4)]
|
||||
|
||||
arm_processes = [multiprocessing.Process(target=collect_arm_data, args=(i, arm_queues[i], arm_events[i])) for i in range(4)]
|
||||
camera_processes = [multiprocessing.Process(target=collect_camera_data, args=(i, camera_queues[i], camera_events[i])) for i in range(4)]
|
||||
|
||||
for p in arm_processes + camera_processes:
|
||||
p.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
for event in arm_events + camera_events:
|
||||
event.wait()
|
||||
|
||||
for i in range(4):
|
||||
if not arm_queues[i].empty():
|
||||
arm_id, arm_data = arm_queues[i].get()
|
||||
print(f"Received from Arm {arm_id}: {arm_data}")
|
||||
arm_events[i].clear()
|
||||
|
||||
if not camera_queues[i].empty():
|
||||
camera_id, camera_data = camera_queues[i].get()
|
||||
print(f"Received from Camera {camera_id}: {camera_data}")
|
||||
camera_events[i].clear()
|
||||
|
||||
time.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
for p in arm_processes + camera_processes:
|
||||
p.terminate()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# test_udp = test_udp()
|
||||
# test_udp.set_udp()
|
||||
4
realman_src/realman_aloha/shadow_rm_robot/.gitignore
vendored
Normal file
4
realman_src/realman_aloha/shadow_rm_robot/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pt
|
||||
0
realman_src/realman_aloha/shadow_rm_robot/README.md
Normal file
0
realman_src/realman_aloha/shadow_rm_robot/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
arm_ip: "192.168.1.18"
|
||||
arm_port: 8080
|
||||
arm_axis: 6
|
||||
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
|
||||
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
|
||||
@@ -0,0 +1,5 @@
|
||||
port: /dev/ttyUSB0
|
||||
right_port: /dev/ttyUSB1
|
||||
baudrate: 460800
|
||||
hex_data: "55 AA 02 00 00 67"
|
||||
arm_axis: 6
|
||||
Binary file not shown.
Binary file not shown.
36
realman_src/realman_aloha/shadow_rm_robot/pyproject.toml
Normal file
36
realman_src/realman_aloha/shadow_rm_robot/pyproject.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[tool.poetry]
|
||||
name = "shadow_rm_robot"
|
||||
version = "0.1.0"
|
||||
description = "Robot package, including operations such as reading and controlling robots"
|
||||
readme = "README.md"
|
||||
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||
license = "MIT"
|
||||
# include = ["realman_vision/pytransform/_pytransform.so",]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux amd64",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10"
|
||||
pyyaml = ">=6.0"
|
||||
pyserial = ">=3.5"
|
||||
pymodbus = ">=3.7"
|
||||
|
||||
|
||||
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
|
||||
pytest = ">=8.3"
|
||||
black = ">=24.10.0"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.8.4"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# cython: language_level=3
|
||||
import os
|
||||
|
||||
import logging
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
|
||||
class CommonLog(object):
|
||||
"""
|
||||
日志记录
|
||||
"""
|
||||
|
||||
def __init__(self, logger, logname='web-log'):
|
||||
self.logname = os.path.join(os.path.dirname(os.path.abspath(__file__)), '%s' % logname)
|
||||
self.logger = logger
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
self.logger.propagate = False # 禁止使用logger对象parent的处理器
|
||||
self.formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s', '%Y-%m-%d %H:%M:%S')
|
||||
|
||||
def __console(self, level, message):
|
||||
# 创建一个FileHandler,用于写到本地
|
||||
|
||||
# fh = TimedRotatingFileHandler(self.logname, when='MIDNIGHT', interval=1, encoding='utf-8')
|
||||
# # fh = logging.FileHandler(self.logname, 'a', encoding='utf-8')
|
||||
# fh.suffix = '%Y-%m-%d.log'
|
||||
# fh.setLevel(logging.DEBUG)
|
||||
# fh.setFormatter(self.formatter)
|
||||
# self.logger.addHandler(fh)
|
||||
|
||||
# 创建一个StreamHandler,用于输出到控制台
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(logging.DEBUG)
|
||||
ch.setFormatter(self.formatter)
|
||||
self.logger.addHandler(ch)
|
||||
|
||||
if level == 'info':
|
||||
self.logger.info(message)
|
||||
elif level == 'debug':
|
||||
self.logger.debug(message)
|
||||
elif level == 'warning':
|
||||
self.logger.warning(message)
|
||||
elif level == 'error':
|
||||
self.logger.error(message, exc_info=1) # 显示错误栈
|
||||
# self.logger.error(message)
|
||||
|
||||
elif level == 'error_':
|
||||
self.logger.error(message) # 不显示错误栈
|
||||
|
||||
|
||||
# 这两行代码是为了避免日志输出重复问题
|
||||
self.logger.removeHandler(ch)
|
||||
# self.logger.removeHandler(fh)
|
||||
# # 关闭打开的文件
|
||||
# fh.close()
|
||||
|
||||
def debug(self, message):
|
||||
self.__console('debug', message)
|
||||
|
||||
def info(self, message):
|
||||
self.__console('info', message)
|
||||
|
||||
def warning(self, message):
|
||||
self.__console('warning', message)
|
||||
|
||||
def error(self, message):
|
||||
self.__console('error', message)
|
||||
|
||||
def error_(self, message):
|
||||
self.__console('error_', message)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
import json
|
||||
import yaml
|
||||
import time
|
||||
import logging
|
||||
import socket
|
||||
import numpy as np
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
class RmArm:
|
||||
def __init__(self, config_file="config.yaml"):
|
||||
"""初始化机械臂的网络连接并发送初始命令。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
"""
|
||||
self.config = self._load_config(config_file)
|
||||
self.arm_ip = self.config.get("arm_ip", "192.168.1.18")
|
||||
self.get_vel = self.config.get("get_vel", True)
|
||||
self.get_torque = self.config.get("get_torque", True)
|
||||
arm_port = self.config.get("arm_port", 8080)
|
||||
local_ip = self.config.get("local_ip", '192.168.1.101')
|
||||
local_port = self.config.get("local_port", 8089)
|
||||
|
||||
self.arm = socket.socket()
|
||||
self.arm.connect((self.arm_ip, arm_port))
|
||||
|
||||
set_udp = {"command":"set_realtime_push","cycle":6,"enable":True,"port":local_port,"ip":local_ip,"custom":{"aloha_state":True,"joint_speed":True,"arm_current_status":True,"hand":False, "expand_state":True}}
|
||||
|
||||
self.arm.send(json.dumps(set_udp).encode('utf-8'))
|
||||
_ = self.arm.recv(1024)
|
||||
|
||||
self.arm_axis = self.config.get("arm_axis", 6)
|
||||
self.arm_ki = self.config.get("arm_ki", [7, 7, 7, 3, 3, 3])
|
||||
|
||||
|
||||
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.udp_socket.bind((local_ip, local_port))
|
||||
|
||||
self.cmd_get_current_arm_state = '{"command":"get_current_arm_state"}\r\n'
|
||||
self.cmd_get_gripper_state = '{"command":"get_gripper_state"}\r\n'
|
||||
|
||||
self.cmd_set_gripper_release = (
|
||||
'{"command": "set_gripper_release", "speed": 500, "block": false}\r\n'
|
||||
)
|
||||
self.cmd_set_gripper_route = (
|
||||
'{"command":"set_gripper_route","min":0,"max":1000}\r\n'
|
||||
)
|
||||
|
||||
self.arm.send(self.cmd_set_gripper_route.encode("utf-8"))
|
||||
_ = self.arm.recv(1024)
|
||||
|
||||
self.pre_gripper_actpos = None
|
||||
self.cur_gripper_actpos = None
|
||||
self.pre_actpos_time = None
|
||||
self.cur_actpos_time = None
|
||||
|
||||
def _load_config(self, config_file):
|
||||
"""加载配置文件。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
|
||||
Returns:
|
||||
dict: 配置文件内容。
|
||||
"""
|
||||
with open(config_file, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
def _json_to_numpy(self, byte_data, key):
|
||||
"""将字节数据解析为 NumPy 数组。
|
||||
|
||||
Args:
|
||||
byte_data (bytes): 字节数据。
|
||||
key (str): JSON 数据中的键。
|
||||
|
||||
Returns:
|
||||
np.ndarray: 解析后的 NumPy 数组。
|
||||
"""
|
||||
str_data = byte_data.decode("utf-8")
|
||||
logging.debug(f"Received KEY: {key}")
|
||||
logging.debug(f"Received JSON data: {str_data}")
|
||||
try:
|
||||
data_list = json.loads(str_data)[key]
|
||||
if isinstance(data_list, dict):
|
||||
return data_list
|
||||
except KeyError:
|
||||
logging.error(f"Key '{key}' not found in JSON data")
|
||||
logging.error(f"Received JSON data: {str_data}")
|
||||
return None
|
||||
return np.array(data_list, dtype=float)
|
||||
|
||||
def set_joint_position(self, joint_angle):
|
||||
"""设置机械臂的位置。
|
||||
|
||||
Args:
|
||||
arm_pos (np.ndarray): 机械臂的位置
|
||||
|
||||
"""
|
||||
joint_angle = np.array(joint_angle)
|
||||
data = np.floor(joint_angle * 1000).astype(int).tolist()
|
||||
cmd = (
|
||||
json.dumps(
|
||||
{"command": "movej", "joint": data, "block": True, "v": 40, "r": 0}
|
||||
)
|
||||
+ "\r\n"
|
||||
)
|
||||
self.arm.send(cmd.encode("utf-8"))
|
||||
# TODO: Pending
|
||||
state = self.arm.recv(1024)
|
||||
# state = self.arm.recv(1024)
|
||||
|
||||
def set_joint_canfd_position(self, joint_angle):
|
||||
"""设置机械臂的位置。
|
||||
|
||||
Args:
|
||||
arm_pos (np.ndarray): 机械臂的位置
|
||||
"""
|
||||
joint_angle = np.array(joint_angle)
|
||||
data = np.floor(joint_angle * 1000).astype(int).tolist()
|
||||
cmd = (
|
||||
json.dumps({"command": "movej_canfd", "joint": data, "follow": False})
|
||||
+ "\r\n"
|
||||
)
|
||||
self.arm.send(cmd.encode("utf-8"))
|
||||
|
||||
def set_gripper_position(self, actpos):
|
||||
"""设置夹爪的位置。
|
||||
|
||||
Args:
|
||||
actpos (np.ndarray): 夹爪的位置,单位为毫米。
|
||||
"""
|
||||
data = np.array(actpos) * 1000
|
||||
data = np.floor(data).astype(int).tolist()
|
||||
cmd = (
|
||||
json.dumps(
|
||||
{
|
||||
"command": "set_gripper_position",
|
||||
"position": data,
|
||||
"block": False,
|
||||
}
|
||||
)
|
||||
+ "\r\n"
|
||||
)
|
||||
self.arm.send(cmd.encode("utf-8"))
|
||||
aaa = self.arm.recv(1024)
|
||||
# print(aaa)
|
||||
|
||||
def _update_state(self, gripper_actpos, actpos_time):
|
||||
"""更新关节和夹爪状态及时间。
|
||||
|
||||
Args:
|
||||
actpos_time (float): 夹爪时间戳。
|
||||
"""
|
||||
self.pre_gripper_actpos, self.cur_gripper_actpos = self.cur_gripper_actpos, gripper_actpos
|
||||
self.pre_actpos_time, self.cur_actpos_time = self.cur_actpos_time, actpos_time
|
||||
|
||||
def get_arm_data(self):
|
||||
"""获取机械臂数据"""
|
||||
data, addr = self.udp_socket.recvfrom(1024)
|
||||
data = json.loads(data.decode('utf-8'))
|
||||
# logging.info(f"Received data: {data}")
|
||||
joint_angle = np.array(data['joint_status']['joint_position']) * 0.001
|
||||
joint_velocity = np.array(data['joint_status']['joint_speed']) * 0.001 if self.get_vel else None
|
||||
joint_current = np.array(data['joint_status']['joint_current']) / 1000000 if self.get_torque else None
|
||||
joint_torque = self.current_to_torque(joint_current) if self.get_torque else None
|
||||
aloha_state = data['aloha_state']
|
||||
# logging.info(f"Time consumed: {time.time() - start_time}")
|
||||
|
||||
result = {'joint_angle': joint_angle, 'aloha_state': aloha_state}
|
||||
if self.get_vel:
|
||||
result['joint_velocity'] = joint_velocity
|
||||
if self.get_torque:
|
||||
result['joint_torque'] = joint_torque
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_gripper_data(self):
|
||||
"""获取夹爪数据"""
|
||||
try:
|
||||
actpos_time = time.time()
|
||||
self.arm.send(self.cmd_get_gripper_state.encode("utf-8"))
|
||||
# gripper_qpos = self.arm.recv(1024)
|
||||
while True:
|
||||
gripper_qpos = self.arm.recv(1024)
|
||||
data = json.loads(gripper_qpos.decode("utf-8"))
|
||||
if "actpos" in data:
|
||||
break
|
||||
else:
|
||||
self.arm.send(self.cmd_get_gripper_state.encode("utf-8"))
|
||||
gripper_actpos = self._json_to_numpy(gripper_qpos, "actpos") * 0.001
|
||||
gripper_velocity = self.get_gripper_velocity() if self.get_vel else None
|
||||
gripper_force = self._json_to_numpy(gripper_qpos, "current_force") / 100 if self.get_torque else None
|
||||
|
||||
result = {'gripper_actpos': gripper_actpos}
|
||||
if self.get_vel:
|
||||
result['gripper_velocity'] = gripper_velocity
|
||||
if self.get_torque:
|
||||
result['gripper_force'] = gripper_force
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting gripper data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_integrate_data(self):
|
||||
"""获取整合数据"""
|
||||
arm_data = self.get_arm_data()
|
||||
gripper_data = self.get_gripper_data()
|
||||
|
||||
if not arm_data or not gripper_data:
|
||||
return None
|
||||
|
||||
result = {'aloha_state': arm_data['aloha_state'], 'arm_angle': np.append(arm_data['joint_angle'], gripper_data['gripper_actpos'])}
|
||||
if self.get_vel:
|
||||
result['arm_velocity'] = np.append(arm_data['joint_velocity'], gripper_data['gripper_velocity'])
|
||||
if self.get_torque:
|
||||
result['arm_torque'] = arm_data['joint_torque'] + [gripper_data['gripper_force']]
|
||||
|
||||
return result
|
||||
|
||||
def get_gripper_velocity(self):
|
||||
"""获取夹爪速度"""
|
||||
if self.pre_actpos_time is None or self.cur_actpos_time is None:
|
||||
logging.debug("Previous or current joint positions are not available.")
|
||||
return 0
|
||||
delta_time = self.cur_actpos_time - self.pre_actpos_time
|
||||
return (self.cur_gripper_actpos - self.pre_gripper_actpos["gripper"]) / delta_time if self.cur_gripper_actpos is not None else 0
|
||||
|
||||
def current_to_torque(self, current):
|
||||
"""将电流转换为扭矩"""
|
||||
return [c * k for c, k in zip(current, self.arm_ki)]
|
||||
|
||||
def get_arm_position(self):
|
||||
"""获取机械臂的位置。
|
||||
|
||||
Returns:
|
||||
dict: 机械臂的位置,单位为毫米和弧度。
|
||||
包含以下键:
|
||||
- 'x': x 轴位置
|
||||
- 'y': y 轴位置
|
||||
- 'z': z 轴位置
|
||||
- 'roll': 滚转角
|
||||
- 'pitch': 俯仰角
|
||||
- 'yaw': 偏航角
|
||||
- 单位 : mm, rad
|
||||
"""
|
||||
self.arm.send(self.cmd_get_current_arm_state.encode("utf-8"))
|
||||
_arm_state = self.arm.recv(1024)
|
||||
arm_state = self._json_to_numpy(_arm_state, "arm_state")
|
||||
arm_pos = np.array(arm_state["pose"], dtype=float) * 0.001
|
||||
|
||||
return {
|
||||
"x": arm_pos[0],
|
||||
"y": arm_pos[1],
|
||||
"z": arm_pos[2],
|
||||
"roll": arm_pos[3],
|
||||
"pitch": arm_pos[4],
|
||||
"yaw": arm_pos[5],
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
arm_left = RmArm("/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
|
||||
# arm_right = RmArm("/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml")
|
||||
# test_left_narry = [7.235, 31.816, 51.237, 2.463, 91.054, 12.04]
|
||||
test_right_narry = [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868]
|
||||
while True:
|
||||
start_time = time.time()
|
||||
arm_left.set_gripper_position(0.2)
|
||||
# left_qpos = arm_left.get_integrate_data()
|
||||
left_qpos = arm_left.get_gripper_data()
|
||||
logging.info(left_qpos)
|
||||
# right_qpos = arm_right.get_arm_data()
|
||||
# logging.info(left_qpos)
|
||||
# logging.info(right_qpos)
|
||||
# time.sleep(0.02)
|
||||
# arm_right.set_joint_canfd_position(test_right_narry)
|
||||
# arm_right.set_gripper_position(0.2)
|
||||
|
||||
|
||||
logging.info(f"Time consumed: {time.time() - start_time}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import time
|
||||
import yaml
|
||||
import serial
|
||||
import logging
|
||||
import binascii
|
||||
import numpy as np
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
class ServoArm:
|
||||
def __init__(self, config_file="config.yaml"):
|
||||
"""初始化机械臂的串口连接并发送初始数据。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
"""
|
||||
self.config = self._load_config(config_file)
|
||||
self.port = self.config["port"]
|
||||
self.baudrate = self.config["baudrate"]
|
||||
self.hex_data = self.config["hex_data"]
|
||||
self.arm_axis = self.config.get("arm_axis", 7)
|
||||
|
||||
self.serial_conn = serial.Serial(self.port, self.baudrate, timeout=0)
|
||||
|
||||
self.bytes_to_send = binascii.unhexlify(self.hex_data.replace(" ", ""))
|
||||
self.serial_conn.write(self.bytes_to_send)
|
||||
time.sleep(1)
|
||||
|
||||
def _load_config(self, config_file):
|
||||
"""加载配置文件。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
|
||||
Returns:
|
||||
dict: 配置文件内容。
|
||||
"""
|
||||
with open(config_file, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
return config
|
||||
|
||||
def _bytes_to_signed_int(self, byte_data):
|
||||
"""将字节数据转换为有符号整数。
|
||||
|
||||
Args:
|
||||
byte_data (bytes): 字节数据。
|
||||
|
||||
Returns:
|
||||
int: 有符号整数。
|
||||
"""
|
||||
return int.from_bytes(byte_data, byteorder="little", signed=True)
|
||||
|
||||
def _parse_joint_data(self, hex_received):
|
||||
"""解析接收到的十六进制数据并提取关节数据。
|
||||
|
||||
Args:
|
||||
hex_received (str): 接收到的十六进制字符串数据。
|
||||
|
||||
Returns:
|
||||
dict: 解析后的关节数据。
|
||||
"""
|
||||
logging.debug(f"hex_received: {hex_received}")
|
||||
joints = {}
|
||||
for i in range(self.arm_axis):
|
||||
start = 14 + i * 10
|
||||
end = start + 8
|
||||
joint_hex = hex_received[start:end]
|
||||
joint_byte_data = bytearray.fromhex(joint_hex)
|
||||
joint_value = self._bytes_to_signed_int(joint_byte_data) / 10000.0
|
||||
joints[f"joint_{i+1}"] = joint_value
|
||||
grasp_start = 14 + self.arm_axis*10
|
||||
grasp_hex = hex_received[grasp_start:grasp_start+8]
|
||||
grasp_byte_data = bytearray.fromhex(grasp_hex)
|
||||
# 夹爪进行归一化处理
|
||||
grasp_value = self._bytes_to_signed_int(grasp_byte_data)/1000
|
||||
# print(grasp_value)
|
||||
|
||||
joints["grasp"] = grasp_value
|
||||
return joints
|
||||
|
||||
def get_joint_actions(self):
|
||||
"""从串口读取数据并解析关节动作。
|
||||
|
||||
Returns:
|
||||
dict: 包含关节数据的字典。
|
||||
"""
|
||||
self.serial_conn.write(self.bytes_to_send)
|
||||
bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
|
||||
hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
|
||||
actions = self._parse_joint_data(hex_received)
|
||||
return actions
|
||||
def set_gripper_action(self, action):
|
||||
"""设置夹爪动作。
|
||||
|
||||
Args:
|
||||
action (int): 夹爪动作值。
|
||||
"""
|
||||
action = int(action * 1000)
|
||||
action_bytes = action.to_bytes(4, byteorder="little", signed=True)
|
||||
self.bytes_to_send = self.bytes_to_send[:74] + action_bytes + self.bytes_to_send[78:]
|
||||
|
||||
if __name__ == "__main__":
|
||||
servo_arm = ServoArm("/home/maic/LYT/lerobot/realman_src/realman_aloha/shadow_rm_robot/config/servo_arm.yaml")
|
||||
while True:
|
||||
joint_actions = servo_arm.get_joint_actions()
|
||||
logging.info(joint_actions)
|
||||
time.sleep(1)
|
||||
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
import json
|
||||
import yaml
|
||||
import numpy as np
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
arm = RmArm("./data/rm_arm.yaml")
|
||||
|
||||
arm.arm.send( '{"command":"set_modbus_mode","port":0,"baudrate":115200,"timeout ":2}\r\n'.encode("utf-8"))
|
||||
|
||||
# arm.arm.send( '{"command":"close_modbus_mode","port":1}\r\n'.encode("utf-8"))
|
||||
|
||||
a = arm.arm.recv(1024)
|
||||
|
||||
logging.debug(a)
|
||||
|
||||
arm.arm.send( '{"command":"read_holding_registers","port":1,"address":14,"device":2}\r\n'.encode("utf-8"))
|
||||
|
||||
b = arm.arm.recv(1024)
|
||||
logging.debug(b)
|
||||
@@ -0,0 +1,90 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
import yaml
|
||||
import numpy as np
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
|
||||
class TestRmArm:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self, tmpdir):
|
||||
# 模拟配置文件
|
||||
self.config_data = {
|
||||
"arm_ip": "192.168.1.18",
|
||||
"arm_port": 8080
|
||||
}
|
||||
self.config_file = tmpdir.join("test_config.yaml")
|
||||
with open(self.config_file, "w") as file:
|
||||
yaml.dump(self.config_data, file)
|
||||
|
||||
# 初始化 RmArm 对象
|
||||
self.rm_arm = RmArm(self.config_file)
|
||||
|
||||
@patch("socket.socket")
|
||||
def test_initialization(self, mock_socket):
|
||||
# 测试初始化
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
|
||||
rm_arm = RmArm(self.config_file)
|
||||
assert rm_arm.arm_ip == self.config_data["arm_ip"]
|
||||
assert rm_arm.arm_port == self.config_data["arm_port"]
|
||||
|
||||
# 检查网络连接初始化
|
||||
mock_socket_instance.connect.assert_called_with((self.config_data["arm_ip"], self.config_data["arm_port"]))
|
||||
|
||||
def test_json_to_numpy(self):
|
||||
# 测试 JSON 数据解析为 NumPy 数组
|
||||
json_data = json.dumps({"joint": [1, 2, 3, 4, 5, 6]})
|
||||
byte_data = json_data.encode('utf-8')
|
||||
result = self.rm_arm._json_to_numpy(byte_data, 'joint')
|
||||
expected_result = np.array([1, 2, 3, 4, 5, 6], dtype=float)
|
||||
np.testing.assert_array_equal(result, expected_result)
|
||||
|
||||
# 测试键不存在的情况
|
||||
json_data = json.dumps({"other_key": [1, 2, 3]})
|
||||
byte_data = json_data.encode('utf-8')
|
||||
result = self.rm_arm._json_to_numpy(byte_data, 'joint')
|
||||
expected_result = np.array([])
|
||||
np.testing.assert_array_equal(result, expected_result)
|
||||
|
||||
def test_generate_command(self):
|
||||
# 测试生成关节命令
|
||||
data = np.array([0.1, 0.2, 0.3])
|
||||
cmd_type = 'joint'
|
||||
result = self.rm_arm._generate_command(data, cmd_type)
|
||||
expected_result = json.dumps({"command": "movej", "joint": [100, 200, 300], "v": 40, "r": 0}) + '\r\n'
|
||||
assert result == expected_result
|
||||
|
||||
# 测试生成夹爪命令
|
||||
data = np.array([500])
|
||||
cmd_type = 'gripper'
|
||||
result = self.rm_arm._generate_command(data, cmd_type)
|
||||
expected_result = json.dumps({"command": "set_gripper_position", "position": [500], "block": False}) + '\r\n'
|
||||
assert result == expected_result
|
||||
|
||||
# @patch("socket.socket")
|
||||
# def test_get_qpos(self, mock_socket):
|
||||
# # 模拟网络返回数据
|
||||
# mock_socket_instance = MagicMock()
|
||||
# mock_socket.return_value = mock_socket_instance
|
||||
# mock_socket_instance.recv.side_effect = [
|
||||
# json.dumps({"joint": [1000, 2000, 3000, 4000, 5000, 6000]}).encode('utf-8'),
|
||||
# json.dumps({"actpos": [700]}).encode('utf-8')
|
||||
# ]
|
||||
|
||||
# rm_arm = RmArm(self.config_file)
|
||||
# qpos = rm_arm.get_qpos()
|
||||
# expected_qpos = {
|
||||
# "joint_1": 1.0,
|
||||
# "joint_2": 2.0,
|
||||
# "joint_3": 3.0,
|
||||
# "joint_4": 4.0,
|
||||
# "joint_5": 5.0,
|
||||
# "joint_6": 6.0,
|
||||
# "gripper": 700.0
|
||||
# }
|
||||
# assert qpos == expected_qpos
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user